<!DOCTYPE html><html lang="en" data-theme="light"><head><meta charset="UTF-8"><meta http-equiv="X-UA-Compatible" content="IE=edge"><meta name="viewport" content="width=device-width, initial-scale=1.0, maximum-scale=1.0"><title>Pytorch框架深度学习入门基础 | 1025's Blog</title><meta name="author" content="1025"><meta name="copyright" content="1025"><meta name="format-detection" content="telephone=no"><meta name="theme-color" content="#ffffff"><meta name="description" content="深度学习入门，让你的机器产生思维吧！">
<meta property="og:type" content="article">
<meta property="og:title" content="Pytorch框架深度学习入门基础">
<meta property="og:url" content="https://lwkblog.gitee.io/2023/03/14/Pytorch%E6%A1%86%E6%9E%B6%E6%B7%B1%E5%BA%A6%E5%AD%A6%E4%B9%A0%E5%85%A5%E9%97%A8%E5%9F%BA%E7%A1%80/index.html">
<meta property="og:site_name" content="1025&#39;s Blog">
<meta property="og:description" content="深度学习入门，让你的机器产生思维吧！">
<meta property="og:locale" content="en_US">
<meta property="og:image" content="https://pic1.zhimg.com/v2-5ea151d43c0ebb13546985d225ac256a_1200x500.jpg">
<meta property="article:published_time" content="2023-03-14T13:06:41.000Z">
<meta property="article:modified_time" content="2023-03-15T10:44:35.099Z">
<meta property="article:author" content="1025">
<meta property="article:tag" content="类型：blog">
<meta name="twitter:card" content="summary">
<meta name="twitter:image" content="https://pic1.zhimg.com/v2-5ea151d43c0ebb13546985d225ac256a_1200x500.jpg"><link rel="shortcut icon" href="https://www.logosc.cn/oss/icons/2022/05/27/nOPDe9rUUvMC63B.png"><link rel="canonical" href="https://lwkblog.gitee.io/2023/03/14/Pytorch%E6%A1%86%E6%9E%B6%E6%B7%B1%E5%BA%A6%E5%AD%A6%E4%B9%A0%E5%85%A5%E9%97%A8%E5%9F%BA%E7%A1%80/index.html"><link rel="preconnect" href="//cdn.jsdelivr.net"/><link rel="preconnect" href="//fonts.googleapis.com" crossorigin=""/><link rel="preconnect" href="//busuanzi.ibruce.info"/><link rel="stylesheet" href="/css/index.css"><link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/@fortawesome/fontawesome-free/css/all.min.css" media="print" onload="this.media='all'"><link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/@fancyapps/ui/dist/fancybox.min.css" media="print" onload="this.media='all'"><link rel="stylesheet" href="https://fonts.googleapis.com/css?family=Titillium+Web&amp;display=swap" media="print" onload="this.media='all'"><script>const GLOBAL_CONFIG = { 
  // Settings literal inlined into the page; the scripts that read it are not visible in this chunk.
  root: '/',
  // Left undefined here, while localSearch below is populated.
  algolia: undefined,
  // "${query}" sits inside an ordinary double-quoted string, so it is stored literally and not
  // interpolated at this point. NOTE(review): presumably substituted by the search script — confirm.
  localSearch: {"path":"/search.xml","preload":true,"languages":{"hits_empty":"We didn't find any results for the search: ${query}"}},
  translate: undefined,
  noticeOutdate: undefined,
  // NOTE(review): "highlighjs" (no second "t") looks like a typo, but the consumer presumably
  // compares against this exact string — verify before changing it.
  highlight: {"plugin":"highlighjs","highlightCopy":true,"highlightLang":true,"highlightHeightLimit":false},
  // Status messages for a copy action — TODO confirm where they are displayed.
  copy: {
    success: 'Copy successfully',
    error: 'Copy error',
    noSupport: 'The browser does not support'
  },
  relativeDate: {
    homepage: false,
    post: false
  },
  runtime: 'days',
  // Suffix strings for relative dates — presumably only used when relativeDate is enabled; confirm.
  date_suffix: {
    just: 'Just',
    min: 'minutes ago',
    hour: 'hours ago',
    day: 'days ago',
    month: 'months ago'
  },
  // NOTE(review): limitCount looks like a character threshold for appending the notice to copied text — confirm.
  copyright: {"limitCount":50,"languages":{"author":"Author: 1025","link":"Link: ","source":"Source: 1025's Blog","info":"Copyright is owned by the author. For commercial reprints, please contact the author for authorization. For non-commercial reprints, please indicate the source."}},
  lightbox: 'fancybox',
  Snackbar: undefined,
  // CDN URLs for an optional gallery script and stylesheet; nothing in this chunk loads them directly.
  source: {
    justifiedGallery: {
      js: 'https://cdn.jsdelivr.net/npm/flickr-justified-gallery/dist/fjGallery.min.js',
      css: 'https://cdn.jsdelivr.net/npm/flickr-justified-gallery/dist/fjGallery.min.css'
    }
  },
  isPhotoFigcaption: true,
  islazyload: false,
  isAnchor: false,
  percent: {
    toc: true,
    rightside: false,
  }
}</script><script id="config-diff">var GLOBAL_CONFIG_SITE = {
  // Values specific to this page. NOTE(review): presumably read by the same scripts that use GLOBAL_CONFIG — confirm.
  title: 'Pytorch框架深度学习入门基础',
  isPost: true,
  isHome: false,
  isHighlightShrink: false,
  isToc: true,
  // Same instant as the article:modified_time meta tag (2023-03-15T10:44:35Z), written with a +8h offset and no zone marker.
  postUpdate: '2023-03-15 18:44:35'
}</script><noscript><style type="text/css">
  /* These rules live inside <noscript>, so they only apply when scripting is disabled. */
  /* NOTE(review): presumably these elements start hidden and are revealed by script — confirm against /css/index.css. */
  #nav,
  .justified-gallery img {
    opacity: 1;
  }

  /* Force the post/listing <time> elements to render inline. */
  #recent-posts time,
  #post-meta time {
    display: inline !important;
  }
</style></noscript><script>(win=>{
    /**
     * Small localStorage wrapper that stores values together with an expiry time.
     * Entries are serialized as JSON: { value, expiry } where expiry is a
     * millisecond timestamp.
     */
    win.saveToLocal = {
      /**
       * Store `value` under `key`.
       * @param {string} key   localStorage key
       * @param {*} value      JSON-serializable value to store
       * @param {number} ttl   lifetime in days (multiplied by 86400000 ms); 0 means "do not store"
       */
      set: function setWithExpiry(key, value, ttl) {
        if (ttl === 0) return
        const now = new Date()
        const expiryDay = ttl * 86400000
        const item = {
          value: value,
          expiry: now.getTime() + expiryDay,
        }
        localStorage.setItem(key, JSON.stringify(item))
      },

      /**
       * Read the value stored under `key`.
       * @param {string} key  localStorage key
       * @returns {*} the stored value, or undefined when the entry is missing,
       *              expired (the entry is then removed), or unreadable
       */
      get: function getWithExpiry(key) {
        let item
        try {
          const itemStr = localStorage.getItem(key)
          if (!itemStr) {
            return undefined
          }
          item = JSON.parse(itemStr)
        } catch (e) {
          // Storage access was refused or the entry is not valid JSON. This function is
          // called inline while the page head is parsed, so an exception here would abort
          // the rest of that script; treat the entry as missing instead.
          return undefined
        }

        // JSON.parse('null') yields null; reading a property on it would throw.
        if (item === null || typeof item !== 'object') {
          return undefined
        }

        const now = new Date()
        if (now.getTime() > item.expiry) {
          localStorage.removeItem(key)
          return undefined
        }
        return item.value
      }
    }
  
    /**
     * Load an external script by appending an async <script> element to <head>.
     * @param {string} url  script URL
     * @returns {Promise<void>} resolves when the script reports it has loaded, rejects on error
     */
    win.getScript = url => new Promise((resolve, reject) => {
      const el = document.createElement('script')

      // Shared handler for onload and onreadystatechange. When a readyState is
      // present, only 'loaded' / 'complete' count as finished.
      const onDone = function () {
        const state = this.readyState
        const finished = !state || state === 'loaded' || state === 'complete'
        if (!finished) return
        // Detach both handlers so resolve() is not reached again through them.
        el.onload = el.onreadystatechange = null
        resolve()
      }

      el.src = url
      el.async = true
      el.onerror = reject
      el.onload = el.onreadystatechange = onDone
      document.head.appendChild(el)
    })
  
    /**
     * Load an external stylesheet by appending a <link rel="stylesheet"> to <head>.
     * @param {string} url  stylesheet URL
     * @returns {Promise<void>} resolves on the link's load event, rejects (with no reason) on error
     */
    win.getCSS = url => new Promise((resolve, reject) => {
      const sheet = Object.assign(document.createElement('link'), {
        rel: 'stylesheet',
        href: url,
        onload: () => resolve(),
        onerror: () => reject()
      })
      document.head.appendChild(sheet)
    })
  
      /**
       * Switch the page to the dark theme: set data-theme="dark" on <html> and,
       * if a <meta name="theme-color"> tag exists, set its content to #0d0d0d.
       */
      win.activateDarkMode = function () {
        document.documentElement.setAttribute('data-theme', 'dark')
        const themeColorMeta = document.querySelector('meta[name="theme-color"]')
        if (themeColorMeta === null) return
        themeColorMeta.setAttribute('content', '#0d0d0d')
      }
      /**
       * Switch the page to the light theme: set data-theme="light" on <html> and,
       * if a <meta name="theme-color"> tag exists, set its content to #ffffff.
       */
      win.activateLightMode = function () {
        document.documentElement.setAttribute('data-theme', 'light')
        const themeColorMeta = document.querySelector('meta[name="theme-color"]')
        if (themeColorMeta === null) return
        themeColorMeta.setAttribute('content', '#ffffff')
      }
      // Re-apply the theme saved under the 'theme' key; any other value leaves the markup default untouched.
      const storedTheme = saveToLocal.get('theme')
      if (storedTheme === 'dark') {
        activateDarkMode()
      } else if (storedTheme === 'light') {
        activateLightMode()
      }

      // Re-apply the saved aside state: 'hide' adds the hide-aside class to <html>,
      // any other stored value removes it, and no stored value changes nothing.
      const storedAside = saveToLocal.get('aside-status')
      if (storedAside !== undefined) {
        document.documentElement.classList.toggle('hide-aside', storedAside === 'hide')
      }
    
    // Add the "apple" class to <html> when the user agent string mentions
    // iPad, iPhone, iPod or Macintosh.
    const detectApple = () => {
      const isApple = /iPad|iPhone|iPod|Macintosh/.test(navigator.userAgent)
      if (isApple) document.documentElement.classList.add('apple')
    }
    detectApple()
    })(window)</script><meta name="generator" content="Hexo 5.4.2"></head><body><div id="loading-box"><div class="loading-left-bg"></div><div class="loading-right-bg"></div><div class="spinner-box"><div class="configure-border-1"><div class="configure-core"></div></div><div class="configure-border-2"><div class="configure-core"></div></div><div class="loading-word">Loading...</div></div></div><script>const preloader = {
  endLoading: () => {
    document.body.style.overflow = 'auto';
    document.getElementById('loading-box').classList.add("loaded")
  },
  // Clear the inline body overflow and remove the "loaded" class from #loading-box.
  initLoading: () => {
    const { body } = document
    body.style.overflow = ''
    const box = document.getElementById('loading-box')
    box.classList.remove('loaded')
  }
}
// Call preloader.endLoading once the window's load event fires.
window.addEventListener('load',()=> { preloader.endLoading() })

// The condition is the constant `true`, so these listeners are always registered.
// NOTE(review): looks like a feature flag inlined when the page was generated — confirm.
if (true) {
  // 'pjax:send' resets the overlay via initLoading ...
  document.addEventListener('pjax:send', () => { preloader.initLoading() })
  // ... and 'pjax:complete' marks it loaded again via endLoading.
  document.addEventListener('pjax:complete', () => { preloader.endLoading() })
}</script><div id="web_bg"></div><div id="sidebar"><div id="menu-mask"></div><div id="sidebar-menus"><div class="avatar-img is-center"><img src="https://pic1.zhimg.com/50/v2-530bac941983a6aae866ab34b6290be2_720w.jpg?source=1940ef5c" onerror="onerror=null;src='/img/friend_404.gif'" alt="avatar"/></div><div class="sidebar-site-data site-data is-center"><a href="/archives/"><div class="headline">Articles</div><div class="length-num">12</div></a><a href="/tags/"><div class="headline">Tags</div><div class="length-num">0</div></a><a href="/categories/"><div class="headline">Categories</div><div class="length-num">5</div></a></div><hr/><div class="menus_items"><div class="menus_item"><a class="site-page" href="/"><i class="fa-fw fas fa-home"></i><span> Home</span></a></div><div class="menus_item"><a class="site-page" href="/archives/"><i class="fa-fw fas fa-archive"></i><span> Archives</span></a></div><div class="menus_item"><a class="site-page" href="/categories/"><i class="fa-fw fas fa-folder-open"></i><span> Categories</span></a></div><div class="menus_item"><a class="site-page group hide" href="javascript:void(0);"><i class="fa-fw fas fa-list"></i><span> List</span><i class="fas fa-chevron-down"></i></a><ul class="menus_item_child"><li><a class="site-page child" href="/gallery/"><i class="fa-fw fas fa-images"></i><span> Gallery</span></a></li><li><a class="site-page child" href="/music/"><i class="fa-fw fas fa-music"></i><span> Music</span></a></li><li><a class="site-page child" href="/movie/"><i class="fa-fw fas fa-video"></i><span> Movie</span></a></li></ul></div><div class="menus_item"><a class="site-page" href="/about/"><i class="fa-fw fas fa-heart"></i><span> About</span></a></div></div></div></div><div class="post" id="body-wrap"><header class="post-bg" id="page-header" style="background-image: url('https://pic1.zhimg.com/v2-5ea151d43c0ebb13546985d225ac256a_1200x500.jpg')"><nav id="nav"><span id="blog-info"><a href="/" title="1025's Blog"><img class="site-icon" 
src="/img/friend_404.gif"/><span class="site-name">1025's Blog</span></a></span><div id="menus"><div id="search-button"><a class="site-page social-icon search" href="javascript:void(0);"><i class="fas fa-search fa-fw"></i><span> Search</span></a></div><div class="menus_items"><div class="menus_item"><a class="site-page" href="/"><i class="fa-fw fas fa-home"></i><span> Home</span></a></div><div class="menus_item"><a class="site-page" href="/archives/"><i class="fa-fw fas fa-archive"></i><span> Archives</span></a></div><div class="menus_item"><a class="site-page" href="/categories/"><i class="fa-fw fas fa-folder-open"></i><span> Categories</span></a></div><div class="menus_item"><a class="site-page group hide" href="javascript:void(0);"><i class="fa-fw fas fa-list"></i><span> List</span><i class="fas fa-chevron-down"></i></a><ul class="menus_item_child"><li><a class="site-page child" href="/gallery/"><i class="fa-fw fas fa-images"></i><span> Gallery</span></a></li><li><a class="site-page child" href="/music/"><i class="fa-fw fas fa-music"></i><span> Music</span></a></li><li><a class="site-page child" href="/movie/"><i class="fa-fw fas fa-video"></i><span> Movie</span></a></li></ul></div><div class="menus_item"><a class="site-page" href="/about/"><i class="fa-fw fas fa-heart"></i><span> About</span></a></div></div><div id="toggle-menu"><a class="site-page" href="javascript:void(0);"><i class="fas fa-bars fa-fw"></i></a></div></div></nav><div id="post-info"><h1 class="post-title">Pytorch框架深度学习入门基础</h1><div id="post-meta"><div class="meta-firstline"><span class="post-meta-date"><i class="far fa-calendar-alt fa-fw post-meta-icon"></i><span class="post-meta-label">Created</span><time class="post-meta-date-created" datetime="2023-03-14T13:06:41.000Z" title="Created 2023-03-14 21:06:41">2023-03-14</time><span class="post-meta-separator">|</span><i class="fas fa-history fa-fw post-meta-icon"></i><span class="post-meta-label">Updated</span><time class="post-meta-date-updated" 
datetime="2023-03-15T10:44:35.099Z" title="Updated 2023-03-15 18:44:35">2023-03-15</time></span><span class="post-meta-categories"><span class="post-meta-separator">|</span><i class="fas fa-inbox fa-fw post-meta-icon"></i><a class="post-meta-categories" href="/categories/Deeplearn/">Deeplearn</a></span></div><div class="meta-secondline"><span class="post-meta-separator">|</span><span class="post-meta-wordcount"><i class="far fa-file-word fa-fw post-meta-icon"></i><span class="post-meta-label">Word count:</span><span class="word-count">3.6k</span><span class="post-meta-separator">|</span><i class="far fa-clock fa-fw post-meta-icon"></i><span class="post-meta-label">Reading time:</span><span>19min</span></span><span class="post-meta-separator">|</span><span class="post-meta-pv-cv" id="" data-flag-title="Pytorch框架深度学习入门基础"><i class="far fa-eye fa-fw post-meta-icon"></i><span class="post-meta-label">Post View:</span><span id="busuanzi_value_page_pv"><i class="fa-solid fa-spinner fa-spin"></i></span></span></div></div></div></header><main class="layout" id="content-inner"><div id="post"><article class="post-content" id="article-container"><h2 id="首先声明！！！">首先声明！！！</h2>
<hr>
<ul>
<li>1.脚本为本人总结，如需使用，请注明出处。</li>
<li>2.Pytorch基于Python编程语言编写脚本。</li>
<li>3.脚本内有注释。</li>
</ul>
<hr>
<h2 id="一、main-demo">一、main_demo</h2>
<figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br></pre></td><td class="code"><pre><span class="line"><span class="comment"># 这是一个示例 Python 脚本。</span></span><br><span class="line"></span><br><span class="line"><span class="comment"># 按 Shift+F10 执行或将其替换为您的代码。</span></span><br><span class="line"><span class="comment"># 按 双击 Shift 在所有地方搜索类、文件、工具窗口、操作和设置。</span></span><br><span class="line"></span><br><span class="line"></span><br><span class="line"><span class="keyword">def</span> <span class="title function_">print_hi</span>(<span class="params">name</span>):</span><br><span class="line">    <span class="comment"># 在下面的代码行中使用断点来调试脚本。</span></span><br><span class="line">    <span class="built_in">print</span>(<span class="string">f&#x27;Hi, <span class="subst">&#123;name&#125;</span>&#x27;</span>)  <span class="comment"># 按 Ctrl+F8 切换断点。</span></span><br><span class="line"></span><br><span class="line"></span><br><span class="line"><span class="comment"># 按间距中的绿色按钮以运行脚本。</span></span><br><span class="line"><span class="keyword">if</span> __name__ == <span class="string">&#x27;__main__&#x27;</span>:</span><br><span class="line">    print_hi(<span class="string">&#x27;PyCharm&#x27;</span>)</span><br><span class="line"></span><br><span class="line"><span class="comment"># 访问 https://www.jetbrains.com/help/pycharm/ 获取 PyCharm 帮助</span></span><br></pre></td></tr></table></figure>
<h2 id="二、read-data（读取数据）">二、read_data（读取数据）</h2>
<figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br><span class="line">32</span><br><span class="line">33</span><br><span class="line">34</span><br><span class="line">35</span><br><span class="line">36</span><br><span class="line">37</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">from</span> torch.utils.data <span class="keyword">import</span> Dataset</span><br><span class="line"><span class="keyword">from</span> PIL <span class="keyword">import</span> Image</span><br><span class="line"><span class="keyword">import</span> os</span><br><span class="line"></span><br><span class="line"></span><br><span class="line"><span class="keyword">class</span> <span class="title class_">MyData</span>(<span class="title class_ inherited__">Dataset</span>):</span><br><span class="line"></span><br><span class="line">    <span class="comment"># 初始化函数</span></span><br><span class="line">    <span class="keyword">def</span> <span class="title 
function_">__init__</span>(<span class="params">self, root_dir, label_dir</span>):</span><br><span class="line">        self.root_dir = root_dir</span><br><span class="line">        self.label_dir = label_dir</span><br><span class="line">        self.path = os.path.join(self.root_dir, self.label_dir)</span><br><span class="line">        self.img_path = os.listdir(self.path)</span><br><span class="line"></span><br><span class="line">    <span class="comment"># 获取数据集</span></span><br><span class="line">    <span class="keyword">def</span> <span class="title function_">__getitem__</span>(<span class="params">self, idx</span>):</span><br><span class="line">        img_name = self.img_path[idx]</span><br><span class="line">        img_item_path = os.path.join(self.root_dir, self.label_dir, img_name)</span><br><span class="line">        img = Image.<span class="built_in">open</span>(img_item_path)</span><br><span class="line">        lable = self.label_dir</span><br><span class="line">        <span class="keyword">return</span> img, lable</span><br><span class="line"></span><br><span class="line">    <span class="comment"># 获取长度</span></span><br><span class="line">    <span class="keyword">def</span> <span class="title function_">__len__</span>(<span class="params">self</span>):</span><br><span class="line">        <span class="keyword">return</span> <span class="built_in">len</span>(self.img_path)</span><br><span class="line"></span><br><span class="line"></span><br><span class="line">root_dir = <span class="string">&quot;dataset/hymenoptera_data/train&quot;</span></span><br><span class="line"></span><br><span class="line">ants_label_dir = <span class="string">&quot;ants&quot;</span></span><br><span class="line">bees_label_dir = <span class="string">&quot;bees&quot;</span></span><br><span class="line"></span><br><span class="line">ants_dataset = MyData(root_dir, ants_label_dir)</span><br><span class="line">bees_dataset = MyData(root_dir, bees_label_dir)</span><br><span 
class="line"></span><br><span class="line">train_dataset = ants_dataset + bees_dataset</span><br><span class="line"></span><br></pre></td></tr></table></figure>
<h2 id="三、test-tb（简单测试）">三、test_tb（简单测试）</h2>
<figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">from</span> torch.utils.tensorboard <span class="keyword">import</span> SummaryWriter</span><br><span class="line"><span class="keyword">import</span> numpy <span class="keyword">as</span> np</span><br><span class="line"><span class="keyword">from</span> PIL <span class="keyword">import</span> Image</span><br><span class="line"></span><br><span class="line"></span><br><span class="line">writer = SummaryWriter(<span class="string">&quot;logs&quot;</span>)</span><br><span class="line">image_path = <span class="string">&quot;dataset/hymenoptera_data/train/bees/16838648_415acd9e3f.jpg&quot;</span></span><br><span class="line">img_PIL = Image.<span class="built_in">open</span>(image_path)</span><br><span class="line">img_array = np.array(img_PIL)</span><br><span class="line"><span class="comment"># print(type(img_array))</span></span><br><span class="line"><span class="comment"># print(img_array.shape)</span></span><br><span class="line"></span><br><span class="line"></span><br><span class="line">writer.add_image(<span class="string">&quot;test&quot;</span>, img_array, <span class="number">2</span>, dataformats=<span class="string">&#x27;HWC&#x27;</span>)</span><br><span class="line"><span class="comment"># y = 
2x</span></span><br><span class="line"><span class="keyword">for</span> i <span class="keyword">in</span> <span class="built_in">range</span>(<span class="number">100</span>):</span><br><span class="line">    writer.add_scalar(<span class="string">&quot;y=2x&quot;</span>, <span class="number">2</span>*i, i)</span><br><span class="line"></span><br><span class="line"></span><br><span class="line">writer.close()</span><br></pre></td></tr></table></figure>
<h3 id="可视化页面展示：">可视化页面展示：</h3>
<p><img src="/2023/03/14/Pytorch%E6%A1%86%E6%9E%B6%E6%B7%B1%E5%BA%A6%E5%AD%A6%E4%B9%A0%E5%85%A5%E9%97%A8%E5%9F%BA%E7%A1%80/test1.png" alt="test1"></p>
<h2 id="四、Transforms（数据转换）">四、Transforms（数据转换）</h2>
<figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">from</span> PIL <span class="keyword">import</span> Image</span><br><span class="line"><span class="keyword">from</span> torchvision <span class="keyword">import</span> transforms</span><br><span class="line"><span class="keyword">from</span> torch.utils.tensorboard <span class="keyword">import</span> SummaryWriter</span><br><span class="line"></span><br><span class="line"></span><br><span class="line"><span class="comment"># python的用法 -&gt; tensor数据类型</span></span><br><span class="line"><span class="comment"># 通过 transforms.ToTensor看两个问题</span></span><br><span class="line"><span class="comment"># 1. transforms该如何使用？</span></span><br><span class="line"><span class="comment"># 2. 
为什么我们需要Tensor数据类型</span></span><br><span class="line"></span><br><span class="line"><span class="comment"># 绝对路径 C:\Users\Administrator\Desktop\编程代码\Python\pytorch深度学习\dataset\hymenoptera_data\train\ants\0013035.jpg</span></span><br><span class="line"><span class="comment"># 相对路径 dataset/hymenoptera_data/train/ants/0013035.jpg</span></span><br><span class="line"></span><br><span class="line">img_path = <span class="string">&quot;dataset/hymenoptera_data/train/ants/0013035.jpg&quot;</span></span><br><span class="line">img = Image.<span class="built_in">open</span>(img_path)</span><br><span class="line"></span><br><span class="line">writer = SummaryWriter(<span class="string">&quot;logs&quot;</span>)</span><br><span class="line"></span><br><span class="line">tensor_trans = transforms.ToTensor()</span><br><span class="line">tensor_img = tensor_trans(img)</span><br><span class="line">writer.add_image(<span class="string">&quot;Tensor_img&quot;</span>, tensor_img)</span><br><span class="line"></span><br><span class="line">writer.close()</span><br></pre></td></tr></table></figure>
<h2 id="五、UseTransforms（使用数据转换）">五、UseTransforms（使用数据转换）</h2>
<figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br><span class="line">32</span><br><span class="line">33</span><br><span class="line">34</span><br><span class="line">35</span><br><span class="line">36</span><br><span class="line">37</span><br><span class="line">38</span><br><span class="line">39</span><br><span class="line">40</span><br><span class="line">41</span><br><span class="line">42</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">from</span> PIL <span class="keyword">import</span> Image</span><br><span class="line"><span class="keyword">from</span> torch.utils.tensorboard <span class="keyword">import</span> SummaryWriter</span><br><span class="line"><span class="keyword">from</span> torchvision <span class="keyword">import</span> transforms</span><br><span class="line"></span><br><span class="line"></span><br><span class="line">writer = SummaryWriter(<span class="string">&quot;logs&quot;</span>)</span><br><span class="line">img = Image.<span 
class="built_in">open</span>(<span class="string">&quot;images/桌面.jpg&quot;</span>)</span><br><span class="line"></span><br><span class="line"><span class="comment"># ToTensor的使用</span></span><br><span class="line">trans_totensor = transforms.ToTensor()</span><br><span class="line">img_tensor = trans_totensor(img)</span><br><span class="line">writer.add_image(<span class="string">&quot;Totensor&quot;</span>, img_tensor)</span><br><span class="line"></span><br><span class="line"><span class="comment"># Normalize</span></span><br><span class="line"><span class="built_in">print</span>(img_tensor[<span class="number">0</span>][<span class="number">0</span>][<span class="number">0</span>])</span><br><span class="line">trans_norm = transforms.Normalize([<span class="number">6</span>, <span class="number">3</span>, <span class="number">2</span>], [<span class="number">9</span>, <span class="number">3</span>, <span class="number">5</span>])</span><br><span class="line">img_norm = trans_norm(img_tensor)</span><br><span class="line"><span class="built_in">print</span>(img_norm[<span class="number">0</span>][<span class="number">0</span>][<span class="number">0</span>])</span><br><span class="line">writer.add_image(<span class="string">&quot;Normalize&quot;</span>, img_norm, <span class="number">2</span>)</span><br><span class="line"></span><br><span class="line"><span class="comment"># Resize</span></span><br><span class="line"><span class="built_in">print</span>(img.size)</span><br><span class="line">trans_resize = transforms.Resize((<span class="number">512</span>, <span class="number">512</span>))</span><br><span class="line">img_resize = trans_resize(img)</span><br><span class="line">img_resize = trans_totensor(img_resize)</span><br><span class="line">writer.add_image(<span class="string">&quot;Resize&quot;</span>, img_resize, <span class="number">0</span>)</span><br><span class="line"><span class="built_in">print</span>(img_resize)</span><br><span 
class="line"></span><br><span class="line"><span class="comment"># Compose - resize - 2</span></span><br><span class="line">trans_resize_2 = transforms.Resize(<span class="number">512</span>)</span><br><span class="line">trans_compose = transforms.Compose([trans_resize_2, trans_totensor])</span><br><span class="line">img_resize_2 = trans_compose(img)</span><br><span class="line">writer.add_image(<span class="string">&quot;Resize&quot;</span>, img_resize_2, <span class="number">1</span>)</span><br><span class="line"></span><br><span class="line"><span class="comment"># RandomCrop</span></span><br><span class="line">trans_random = transforms.RandomCrop(<span class="number">512</span>)</span><br><span class="line">trans_compose_2 = transforms.Compose([trans_random, trans_totensor])</span><br><span class="line"><span class="keyword">for</span> i <span class="keyword">in</span> <span class="built_in">range</span>(<span class="number">10</span>):</span><br><span class="line">    img_crop = trans_compose_2(img)</span><br><span class="line">    writer.add_image(<span class="string">&quot;RandomCrop&quot;</span>, img_crop, i)</span><br><span class="line"></span><br><span class="line">writer.close()</span><br></pre></td></tr></table></figure>
<h2 id="六、dataset-transforms（数据集使用）">六、dataset_transforms（数据集使用）</h2>
<figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">import</span> torchvision</span><br><span class="line"><span class="keyword">from</span> torch.utils.tensorboard <span class="keyword">import</span> SummaryWriter</span><br><span class="line"></span><br><span class="line"></span><br><span class="line">dataset_transform = torchvision.transforms.Compose([</span><br><span class="line">    torchvision.transforms.ToTensor()</span><br><span class="line">])</span><br><span class="line"></span><br><span class="line">train_set = torchvision.datasets.CIFAR10(root=<span class="string">&quot;./dataset&quot;</span>, train=<span class="literal">True</span>, transform=dataset_transform, download=<span class="literal">True</span>)</span><br><span class="line">test_set = torchvision.datasets.CIFAR10(root=<span class="string">&quot;./dataset&quot;</span>, train=<span class="literal">False</span>, transform=dataset_transform, download=<span class="literal">True</span>)</span><br><span class="line"></span><br><span class="line"><span class="comment"># print(test_set[0])</span></span><br><span class="line"></span><br><span class="line">writer = SummaryWriter(<span class="string">&quot;../p10&quot;</span>)</span><br><span class="line"><span class="keyword">for</span> i <span class="keyword">in</span> <span 
class="built_in">range</span>(<span class="number">10</span>):</span><br><span class="line">    img, target = test_set[i]</span><br><span class="line">    writer.add_image(<span class="string">&quot;test_set&quot;</span>, img, i)</span><br><span class="line"></span><br><span class="line">writer.close()</span><br></pre></td></tr></table></figure>
<h2 id="七、read-data（读取数据）">七、read_data（读取数据）</h2>
<figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br><span class="line">32</span><br><span class="line">33</span><br><span class="line">34</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">from</span> torch.utils.data <span class="keyword">import</span> Dataset</span><br><span class="line"><span class="keyword">from</span> PIL <span class="keyword">import</span> Image</span><br><span class="line"><span class="keyword">import</span> os</span><br><span class="line"></span><br><span class="line"></span><br><span class="line"><span class="keyword">class</span> <span class="title class_">MyData</span>(<span class="title class_ inherited__">Dataset</span>):</span><br><span class="line"></span><br><span class="line">    <span class="comment"># 初始化函数</span></span><br><span class="line">    <span class="keyword">def</span> <span class="title function_">__init__</span>(<span class="params">self, root_dir, label_dir</span>):</span><br><span class="line">        
self.root_dir = root_dir</span><br><span class="line">        self.label_dir = label_dir</span><br><span class="line">        self.path = os.path.join(self.root_dir, self.label_dir)</span><br><span class="line">        self.img_path = os.listdir(self.path)</span><br><span class="line"></span><br><span class="line">    <span class="comment"># 获取数据集</span></span><br><span class="line">    <span class="keyword">def</span> <span class="title function_">__getitem__</span>(<span class="params">self, idx</span>):</span><br><span class="line">        img_name = self.img_path[idx]</span><br><span class="line">        img_item_path = os.path.join(self.root_dir, self.label_dir, img_name)</span><br><span class="line">        img = Image.<span class="built_in">open</span>(img_item_path)</span><br><span class="line">        label = self.label_dir</span><br><span class="line">        <span class="keyword">return</span> img, label</span><br><span class="line"></span><br><span class="line">    <span class="comment"># 获取长度</span></span><br><span class="line">    <span class="keyword">def</span> <span class="title function_">__len__</span>(<span class="params">self</span>):</span><br><span class="line">        <span class="keyword">return</span> <span class="built_in">len</span>(self.img_path)</span><br><span class="line"></span><br><span class="line"></span><br><span class="line">root_dir = <span class="string">&quot;../dataset/hymenoptera_data/train&quot;</span></span><br><span class="line">ants_label_dir = <span class="string">&quot;ants&quot;</span></span><br><span class="line">bees_label_dir = <span class="string">&quot;bees&quot;</span></span><br><span class="line">ants_dataset = MyData(root_dir, ants_label_dir)</span><br><span class="line">bees_dataset = MyData(root_dir, bees_label_dir)</span><br><span class="line"></span><br><span class="line">train_dataset = ants_dataset + bees_dataset</span><br></pre></td></tr></table></figure>
<h2 id="八、nn-module（模型基础）">八、nn_module（模型基础）</h2>
<figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">import</span> torch</span><br><span class="line"><span class="keyword">from</span> torch <span class="keyword">import</span> nn</span><br><span class="line"></span><br><span class="line"></span><br><span class="line"><span class="keyword">class</span> <span class="title class_">Tudui</span>(nn.Module):</span><br><span class="line">    <span class="keyword">def</span> <span class="title function_">__init__</span>(<span class="params">self</span>):</span><br><span class="line">        <span class="built_in">super</span>().__init__()</span><br><span class="line"></span><br><span class="line">    <span class="keyword">def</span> <span class="title function_">forward</span>(<span class="params">self, <span class="built_in">input</span></span>):</span><br><span class="line">        output = <span class="built_in">input</span> + <span class="number">1</span></span><br><span class="line">        <span class="keyword">return</span> output</span><br><span class="line"></span><br><span class="line"></span><br><span class="line">tudui = Tudui()</span><br><span class="line">x = torch.tensor(<span class="number">1.0</span>)</span><br><span class="line">output = tudui(x)</span><br><span class="line"><span class="built_in">print</span>(output)</span><br></pre></td></tr></table></figure>
<h2 id="九、nn-conv2d（卷积层）">九、nn_conv2d（卷积层）</h2>
<figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br><span class="line">32</span><br><span class="line">33</span><br><span class="line">34</span><br><span class="line">35</span><br><span class="line">36</span><br><span class="line">37</span><br><span class="line">38</span><br><span class="line">39</span><br><span class="line">40</span><br><span class="line">41</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">import</span> torch</span><br><span class="line"><span class="keyword">import</span> torchvision</span><br><span class="line"><span class="keyword">from</span> torch <span class="keyword">import</span> nn</span><br><span class="line"><span class="keyword">from</span> torch.nn <span class="keyword">import</span> Conv2d</span><br><span class="line"><span class="keyword">from</span> torch.utils.data <span class="keyword">import</span> DataLoader</span><br><span class="line"><span class="keyword">from</span> torch.utils.tensorboard <span 
class="keyword">import</span> SummaryWriter</span><br><span class="line"></span><br><span class="line"></span><br><span class="line">dataset = torchvision.datasets.CIFAR10(<span class="string">&quot;../data&quot;</span>, train=<span class="literal">False</span>,</span><br><span class="line">        transform=torchvision.transforms.ToTensor(), download=<span class="literal">True</span>)</span><br><span class="line"></span><br><span class="line">dataloader = DataLoader(dataset, batch_size=<span class="number">64</span>)</span><br><span class="line"></span><br><span class="line"><span class="keyword">class</span> <span class="title class_">Tudui</span>(nn.Module):</span><br><span class="line">    <span class="keyword">def</span> <span class="title function_">__init__</span>(<span class="params">self</span>):</span><br><span class="line">        <span class="built_in">super</span>(Tudui, self).__init__()</span><br><span class="line">        self.conv1 = Conv2d(in_channels=<span class="number">3</span>, out_channels=<span class="number">6</span>, kernel_size=<span class="number">3</span>, stride=<span class="number">1</span>, padding=<span class="number">0</span>)</span><br><span class="line"></span><br><span class="line">    <span class="keyword">def</span> <span class="title function_">forward</span>(<span class="params">self, x</span>):</span><br><span class="line">        x = self.conv1(x)</span><br><span class="line">        <span class="keyword">return</span> x</span><br><span class="line"></span><br><span class="line">tudui = Tudui()</span><br><span class="line"></span><br><span class="line">writer = SummaryWriter(<span class="string">&quot;../logs&quot;</span>)</span><br><span class="line"></span><br><span class="line">step = <span class="number">0</span></span><br><span class="line"><span class="keyword">for</span> data <span class="keyword">in</span> dataloader:</span><br><span class="line">    imgs, targets = data</span><br><span class="line">    output = 
tudui(imgs)</span><br><span class="line">    <span class="comment"># print(imgs.shape)</span></span><br><span class="line">    <span class="comment"># print(output.shape)</span></span><br><span class="line"></span><br><span class="line">    <span class="comment"># torch.Size([64, 3, 32, 32])</span></span><br><span class="line">    writer.add_images(<span class="string">&quot;input&quot;</span>, imgs, step)</span><br><span class="line"></span><br><span class="line">    <span class="comment"># torch.Size([64, 6, 30, 30]) -&gt; [xxx, 3, 30, 30]</span></span><br><span class="line">    output = torch.reshape(output, (-<span class="number">1</span>, <span class="number">3</span>, <span class="number">30</span>, <span class="number">30</span>))</span><br><span class="line">    writer.add_images(<span class="string">&quot;output&quot;</span>, output, step)</span><br><span class="line"></span><br><span class="line">    step += <span class="number">1</span></span><br></pre></td></tr></table></figure>
<h2 id="十、nn-maxpool（最大池化）">十、nn_maxpool（最大池化）</h2>
<figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br><span class="line">32</span><br><span class="line">33</span><br><span class="line">34</span><br><span class="line">35</span><br><span class="line">36</span><br><span class="line">37</span><br><span class="line">38</span><br><span class="line">39</span><br><span class="line">40</span><br><span class="line">41</span><br><span class="line">42</span><br><span class="line">43</span><br><span class="line">44</span><br><span class="line">45</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">import</span> torch</span><br><span class="line"><span class="keyword">import</span> torchvision</span><br><span class="line"><span class="keyword">from</span> torch <span class="keyword">import</span> nn</span><br><span class="line"><span class="keyword">from</span> torch.nn <span class="keyword">import</span> MaxPool2d</span><br><span class="line"><span class="keyword">from</span> torch.utils.data <span 
class="keyword">import</span> DataLoader</span><br><span class="line"><span class="keyword">from</span> torch.utils.tensorboard <span class="keyword">import</span> SummaryWriter</span><br><span class="line"></span><br><span class="line"></span><br><span class="line">dataset = torchvision.datasets.CIFAR10(<span class="string">&quot;../data&quot;</span>, train=<span class="literal">False</span>, download=<span class="literal">True</span>,</span><br><span class="line">                                       transform=torchvision.transforms.ToTensor())</span><br><span class="line"></span><br><span class="line">dataloader = DataLoader(dataset, batch_size=<span class="number">64</span>)</span><br><span class="line"></span><br><span class="line"><span class="comment"># 二维矩阵最大池化</span></span><br><span class="line"><span class="comment"># input = torch.tensor([[1, 2, 0, 3, 1],</span></span><br><span class="line"><span class="comment">#                      [0, 1, 2, 3, 1],</span></span><br><span class="line"><span class="comment">#                      [1, 2, 1, 0, 0],</span></span><br><span class="line"><span class="comment">#                      [5, 2, 3, 1, 1],</span></span><br><span class="line"><span class="comment">#                      [2, 1, 0, 1, 1]], dtype=torch.float32)</span></span><br><span class="line"><span class="comment">#</span></span><br><span class="line"><span class="comment"># input = torch.reshape(input, (-1, 1, 5, 5))</span></span><br><span class="line"><span class="comment"># print(input.shape)</span></span><br><span class="line"></span><br><span class="line"><span class="keyword">class</span> <span class="title class_">Tudui</span>(nn.Module):</span><br><span class="line">    <span class="keyword">def</span> <span class="title function_">__init__</span>(<span class="params">self</span>):</span><br><span class="line">        <span class="built_in">super</span>(Tudui, self).__init__()</span><br><span class="line">        self.maxpool1 = 
MaxPool2d(kernel_size=<span class="number">3</span>, ceil_mode=<span class="literal">True</span>)</span><br><span class="line"></span><br><span class="line">    <span class="keyword">def</span> <span class="title function_">forward</span>(<span class="params">self, <span class="built_in">input</span></span>):</span><br><span class="line">        output = self.maxpool1(<span class="built_in">input</span>)</span><br><span class="line">        <span class="keyword">return</span> output</span><br><span class="line"></span><br><span class="line">tudui = Tudui()</span><br><span class="line"></span><br><span class="line">writer = SummaryWriter(<span class="string">&quot;../logs_maxpool&quot;</span>)</span><br><span class="line">step = <span class="number">0</span></span><br><span class="line"></span><br><span class="line"><span class="keyword">for</span> data <span class="keyword">in</span> dataloader:</span><br><span class="line">    imgs, targets = data</span><br><span class="line">    writer.add_images(<span class="string">&quot;input&quot;</span>, imgs, step)</span><br><span class="line">    output = tudui(imgs)</span><br><span class="line">    writer.add_images(<span class="string">&quot;output&quot;</span>, output, step)</span><br><span class="line">    step += <span class="number">1</span></span><br><span class="line"></span><br><span class="line">writer.close()</span><br></pre></td></tr></table></figure>
<h2 id="十一、nn-relu（非线性激活）">十一、nn_relu（非线性激活）</h2>
<figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br><span class="line">32</span><br><span class="line">33</span><br><span class="line">34</span><br><span class="line">35</span><br><span class="line">36</span><br><span class="line">37</span><br><span class="line">38</span><br><span class="line">39</span><br><span class="line">40</span><br><span class="line">41</span><br><span class="line">42</span><br><span class="line">43</span><br><span class="line">44</span><br><span class="line">45</span><br><span class="line">46</span><br><span class="line">47</span><br><span class="line">48</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">import</span> torch</span><br><span class="line"><span class="keyword">import</span> torchvision</span><br><span class="line"><span class="keyword">from</span> torch <span class="keyword">import</span> nn</span><br><span class="line"><span class="keyword">from</span> torch.nn <span class="keyword">import</span> ReLU, 
Sigmoid</span><br><span class="line"><span class="keyword">from</span> torch.utils.data <span class="keyword">import</span> DataLoader</span><br><span class="line"><span class="keyword">from</span> torch.utils.tensorboard <span class="keyword">import</span> SummaryWriter</span><br><span class="line"></span><br><span class="line"></span><br><span class="line"><span class="built_in">input</span> = torch.tensor([[<span class="number">1</span>, -<span class="number">0.5</span>],</span><br><span class="line">                       [-<span class="number">1</span>, <span class="number">3</span>]])</span><br><span class="line"></span><br><span class="line"></span><br><span class="line"><span class="built_in">input</span> = torch.reshape(<span class="built_in">input</span>, (-<span class="number">1</span>, <span class="number">1</span>, <span class="number">2</span>, <span class="number">2</span>))</span><br><span class="line"></span><br><span class="line"></span><br><span class="line">dataset = torchvision.datasets.CIFAR10(<span class="string">&quot;../data&quot;</span>, train=<span class="literal">False</span>, download=<span class="literal">True</span>,</span><br><span class="line">                                       transform=torchvision.transforms.ToTensor())</span><br><span class="line"></span><br><span class="line">dataloader = DataLoader(dataset, batch_size=<span class="number">64</span>)</span><br><span class="line"></span><br><span class="line"></span><br><span class="line"><span class="keyword">class</span> <span class="title class_">Tudui</span>(nn.Module):</span><br><span class="line">    <span class="keyword">def</span> <span class="title function_">__init__</span>(<span class="params">self</span>):</span><br><span class="line">        <span class="built_in">super</span>(Tudui, self).__init__()</span><br><span class="line">        self.relu1 = ReLU()</span><br><span class="line">        self.sigmoid1 = Sigmoid()</span><br><span class="line"></span><br><span 
class="line">    <span class="keyword">def</span> <span class="title function_">forward</span>(<span class="params">self, <span class="built_in">input</span></span>):</span><br><span class="line">        output = self.sigmoid1(<span class="built_in">input</span>)</span><br><span class="line">        <span class="keyword">return</span> output</span><br><span class="line"></span><br><span class="line"></span><br><span class="line">tudui = Tudui()</span><br><span class="line"><span class="comment"># output = tudui(input)</span></span><br><span class="line"><span class="comment"># print(output)</span></span><br><span class="line"></span><br><span class="line"></span><br><span class="line">writer = SummaryWriter(<span class="string">&quot;../logs_relu&quot;</span>)</span><br><span class="line">step = <span class="number">0</span></span><br><span class="line"></span><br><span class="line"><span class="keyword">for</span> data <span class="keyword">in</span> dataloader:</span><br><span class="line">    imgs, targets = data</span><br><span class="line">    writer.add_images(<span class="string">&quot;input&quot;</span>, imgs, step)</span><br><span class="line">    output = tudui(imgs)</span><br><span class="line">    writer.add_images(<span class="string">&quot;output&quot;</span>, output, step)</span><br><span class="line">    step += <span class="number">1</span></span><br><span class="line"></span><br><span class="line">writer.close()</span><br></pre></td></tr></table></figure>
<h2 id="十二、nn-linear（线性层）">十二、nn_linear（线性层）</h2>
<figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br><span class="line">32</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">import</span> torch</span><br><span class="line"><span class="keyword">import</span> torchvision</span><br><span class="line"><span class="keyword">from</span> torch <span class="keyword">import</span> nn</span><br><span class="line"><span class="keyword">from</span> torch.nn <span class="keyword">import</span> Linear</span><br><span class="line"><span class="keyword">from</span> torch.utils.data <span class="keyword">import</span> DataLoader</span><br><span class="line"></span><br><span class="line"></span><br><span class="line">dataset = torchvision.datasets.CIFAR10(<span class="string">&quot;../data&quot;</span>, train=<span class="literal">False</span>, download=<span class="literal">True</span>,</span><br><span class="line">                                       transform=torchvision.transforms.ToTensor())</span><br><span 
class="line"></span><br><span class="line">dataloader = DataLoader(dataset, batch_size=<span class="number">64</span>, drop_last=<span class="literal">True</span>)</span><br><span class="line"></span><br><span class="line"></span><br><span class="line"><span class="keyword">class</span> <span class="title class_">Tudui</span>(nn.Module):</span><br><span class="line">    <span class="keyword">def</span> <span class="title function_">__init__</span>(<span class="params">self</span>):</span><br><span class="line">        <span class="built_in">super</span>(Tudui, self).__init__()</span><br><span class="line">        self.linear1 = Linear(<span class="number">196608</span>, <span class="number">10</span>)</span><br><span class="line"></span><br><span class="line">    <span class="keyword">def</span> <span class="title function_">forward</span>(<span class="params">self, <span class="built_in">input</span></span>):</span><br><span class="line">        output = self.linear1(<span class="built_in">input</span>)</span><br><span class="line">        <span class="keyword">return</span> output</span><br><span class="line"></span><br><span class="line"></span><br><span class="line">tudui = Tudui()</span><br><span class="line"></span><br><span class="line"><span class="keyword">for</span> data <span class="keyword">in</span> dataloader:</span><br><span class="line">    imgs, targets = data</span><br><span class="line">    <span class="built_in">print</span>(imgs.shape)</span><br><span class="line">    output = torch.reshape(imgs, (<span class="number">1</span>, <span class="number">1</span>, <span class="number">1</span>, -<span class="number">1</span>))</span><br><span class="line">    <span class="built_in">print</span>(output.shape)</span><br><span class="line">    output = tudui(output)</span><br><span class="line">    <span class="built_in">print</span>(output.shape)</span><br></pre></td></tr></table></figure>
<h2 id="十三、nn-seq（搭建小实战）">十三、nn_seq（搭建小实战）</h2>
<figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br><span class="line">32</span><br><span class="line">33</span><br><span class="line">34</span><br><span class="line">35</span><br><span class="line">36</span><br><span class="line">37</span><br><span class="line">38</span><br><span class="line">39</span><br><span class="line">40</span><br><span class="line">41</span><br><span class="line">42</span><br><span class="line">43</span><br><span class="line">44</span><br><span class="line">45</span><br><span class="line">46</span><br><span class="line">47</span><br><span class="line">48</span><br><span class="line">49</span><br><span class="line">50</span><br><span class="line">51</span><br><span class="line">52</span><br><span class="line">53</span><br><span class="line">54</span><br><span class="line">55</span><br><span class="line">56</span><br><span class="line">57</span><br><span class="line">58</span><br><span class="line">59</span><br><span class="line">60</span><br><span 
class="line">61</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">import</span> torch</span><br><span class="line"><span class="keyword">from</span> torch <span class="keyword">import</span> nn</span><br><span class="line"><span class="keyword">from</span> torch.nn <span class="keyword">import</span> Conv2d, MaxPool2d, Flatten, Linear, Sequential</span><br><span class="line"><span class="keyword">from</span> torch.utils.tensorboard <span class="keyword">import</span> SummaryWriter</span><br><span class="line"></span><br><span class="line"></span><br><span class="line"><span class="keyword">class</span> <span class="title class_">Tudui</span>(nn.Module):</span><br><span class="line"></span><br><span class="line">    <span class="keyword">def</span> <span class="title function_">__init__</span>(<span class="params">self</span>):</span><br><span class="line">        <span class="built_in">super</span>(Tudui, self).__init__()</span><br><span class="line"></span><br><span class="line">        <span class="comment"># self.conv1 = Conv2d(3, 32, 5, padding=2)</span></span><br><span class="line">        <span class="comment"># self.maxpool1 = MaxPool2d(2)</span></span><br><span class="line">        <span class="comment"># self.conv2 = Conv2d(32, 32, 5, padding=2)</span></span><br><span class="line">        <span class="comment"># self.maxpool2 = MaxPool2d(2)</span></span><br><span class="line">        <span class="comment"># self.conv3 = Conv2d(32, 64, 5, padding=2)</span></span><br><span class="line">        <span class="comment"># self.maxpool3 = MaxPool2d(2)</span></span><br><span class="line">        <span class="comment"># self.flatten = Flatten()</span></span><br><span class="line">        <span class="comment"># self.linear1 = Linear(1024, 64)</span></span><br><span class="line">        <span class="comment"># self.linear2 = Linear(64, 10)</span></span><br><span class="line"></span><br><span class="line">        self.model1 = 
Sequential(</span><br><span class="line">            Conv2d(<span class="number">3</span>, <span class="number">32</span>, <span class="number">5</span>, padding=<span class="number">2</span>),</span><br><span class="line">            MaxPool2d(<span class="number">2</span>),</span><br><span class="line">            Conv2d(<span class="number">32</span>, <span class="number">32</span>, <span class="number">5</span>, padding=<span class="number">2</span>),</span><br><span class="line">            MaxPool2d(<span class="number">2</span>),</span><br><span class="line">            Conv2d(<span class="number">32</span>, <span class="number">64</span>, <span class="number">5</span>, padding=<span class="number">2</span>),</span><br><span class="line">            MaxPool2d(<span class="number">2</span>),</span><br><span class="line">            Flatten(),</span><br><span class="line">            Linear(<span class="number">1024</span>, <span class="number">64</span>),</span><br><span class="line">            Linear(<span class="number">64</span>, <span class="number">10</span>)</span><br><span class="line">        )</span><br><span class="line"></span><br><span class="line">    <span class="keyword">def</span> <span class="title function_">forward</span>(<span class="params">self, x</span>):</span><br><span class="line"></span><br><span class="line">        <span class="comment"># x = self.conv1(x)</span></span><br><span class="line">        <span class="comment"># x = self.maxpool1(x)</span></span><br><span class="line">        <span class="comment"># x = self.conv2(x)</span></span><br><span class="line">        <span class="comment"># x = self.maxpool2(x)</span></span><br><span class="line">        <span class="comment"># x = self.conv3(x)</span></span><br><span class="line">        <span class="comment"># x = self.maxpool3(x)</span></span><br><span class="line">        <span class="comment"># x = self.flatten(x)</span></span><br><span class="line">        <span 
class="comment"># x = self.linear1(x)</span></span><br><span class="line">        <span class="comment"># x = self.linear2(x)</span></span><br><span class="line"></span><br><span class="line">        x = self.model1(x)</span><br><span class="line">        <span class="keyword">return</span> x</span><br><span class="line"></span><br><span class="line"></span><br><span class="line">tudui = Tudui()</span><br><span class="line"><span class="built_in">print</span>(tudui)</span><br><span class="line"></span><br><span class="line"><span class="built_in">input</span> = torch.ones((<span class="number">64</span>, <span class="number">3</span>, <span class="number">32</span>, <span class="number">32</span>))</span><br><span class="line">output = tudui(<span class="built_in">input</span>)</span><br><span class="line"><span class="built_in">print</span>(output.shape)</span><br><span class="line"></span><br><span class="line"></span><br><span class="line">writer = SummaryWriter(<span class="string">&quot;../logs_seq&quot;</span>)</span><br><span class="line">writer.add_graph(tudui, <span class="built_in">input</span>)</span><br><span class="line"></span><br><span class="line">writer.close()</span><br></pre></td></tr></table></figure>
<h2 id="十四、损失函数与反向传播">十四、损失函数与反向传播</h2>
<h3 id="1-nn-loss">1.nn_loss</h3>
<figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">import</span> torch</span><br><span class="line"><span class="keyword">from</span> torch.nn <span class="keyword">import</span> L1Loss</span><br><span class="line"><span class="keyword">from</span> torch <span class="keyword">import</span> nn</span><br><span class="line"></span><br><span class="line"></span><br><span class="line">inputs = torch.tensor([<span class="number">1</span>, <span class="number">2</span>, <span class="number">3</span>], dtype=torch.float32)</span><br><span class="line">targets = torch.tensor([<span class="number">1</span>, <span class="number">2</span>, <span class="number">5</span>], dtype=torch.float32)</span><br><span class="line"></span><br><span class="line">inputs = torch.reshape(inputs, (<span class="number">1</span>, <span class="number">1</span>, <span class="number">1</span>, <span class="number">3</span>))</span><br><span class="line">targets = torch.reshape(targets, (<span class="number">1</span>, <span class="number">1</span>, <span class="number">1</span>, <span class="number">3</span>))</span><br><span class="line"></span><br><span class="line">loss = L1Loss(reduction=<span class="string">&#x27;sum&#x27;</span>)</span><br><span class="line">result = loss(inputs, targets)</span><br><span 
class="line"></span><br><span class="line">loss_mse = nn.MSELoss()</span><br><span class="line">result_mse = loss_mse(inputs, targets)</span><br><span class="line"></span><br><span class="line"><span class="built_in">print</span>(result)</span><br><span class="line"><span class="built_in">print</span>(result_mse)</span><br></pre></td></tr></table></figure>
<h3 id="2-nn-loss-network">2.nn_loss_network</h3>
<figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br><span class="line">32</span><br><span class="line">33</span><br><span class="line">34</span><br><span class="line">35</span><br><span class="line">36</span><br><span class="line">37</span><br><span class="line">38</span><br><span class="line">39</span><br><span class="line">40</span><br><span class="line">41</span><br><span class="line">42</span><br><span class="line">43</span><br><span class="line">44</span><br><span class="line">45</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">import</span> torch</span><br><span class="line"><span class="keyword">import</span> torchvision</span><br><span class="line"><span class="keyword">from</span> torch <span class="keyword">import</span> nn</span><br><span class="line"><span class="keyword">from</span> torch.nn <span class="keyword">import</span> Conv2d, MaxPool2d, Flatten, Linear, Sequential</span><br><span class="line"><span class="keyword">from</span> 
torch.utils.data <span class="keyword">import</span> DataLoader</span><br><span class="line"><span class="keyword">from</span> torch.utils.tensorboard <span class="keyword">import</span> SummaryWriter</span><br><span class="line"></span><br><span class="line"></span><br><span class="line"><span class="comment"># 导入数据集</span></span><br><span class="line">dataset = torchvision.datasets.CIFAR10(<span class="string">&quot;../data&quot;</span>, train=<span class="literal">False</span>,</span><br><span class="line">        transform=torchvision.transforms.ToTensor(), download=<span class="literal">True</span>)</span><br><span class="line"></span><br><span class="line">dataloader = DataLoader(dataset, batch_size=<span class="number">64</span>)</span><br><span class="line"></span><br><span class="line"></span><br><span class="line"><span class="comment"># 神经网络模板</span></span><br><span class="line"><span class="keyword">class</span> <span class="title class_">Tudui</span>(nn.Module):</span><br><span class="line"></span><br><span class="line">    <span class="keyword">def</span> <span class="title function_">__init__</span>(<span class="params">self</span>):</span><br><span class="line">        <span class="built_in">super</span>(Tudui, self).__init__()</span><br><span class="line">        self.model1 = Sequential(</span><br><span class="line">            Conv2d(<span class="number">3</span>, <span class="number">32</span>, <span class="number">5</span>, padding=<span class="number">2</span>),</span><br><span class="line">            MaxPool2d(<span class="number">2</span>),</span><br><span class="line">            Conv2d(<span class="number">32</span>, <span class="number">32</span>, <span class="number">5</span>, padding=<span class="number">2</span>),</span><br><span class="line">            MaxPool2d(<span class="number">2</span>),</span><br><span class="line">            Conv2d(<span class="number">32</span>, <span class="number">64</span>, <span 
class="number">5</span>, padding=<span class="number">2</span>),</span><br><span class="line">            MaxPool2d(<span class="number">2</span>),</span><br><span class="line">            Flatten(),</span><br><span class="line">            Linear(<span class="number">1024</span>, <span class="number">64</span>),</span><br><span class="line">            Linear(<span class="number">64</span>, <span class="number">10</span>)</span><br><span class="line">        )</span><br><span class="line"></span><br><span class="line">    <span class="keyword">def</span> <span class="title function_">forward</span>(<span class="params">self, x</span>):</span><br><span class="line">        x = self.model1(x)</span><br><span class="line">        <span class="keyword">return</span> x</span><br><span class="line"></span><br><span class="line"></span><br><span class="line"><span class="comment"># 计算损失</span></span><br><span class="line">loss = nn.CrossEntropyLoss()</span><br><span class="line">tudui = Tudui()</span><br><span class="line"><span class="keyword">for</span> data <span class="keyword">in</span> dataloader:</span><br><span class="line">    imgs, targets = data</span><br><span class="line">    outputs = tudui(imgs)</span><br><span class="line">    result_loss = loss(outputs, targets)</span><br><span class="line">    result_loss.backward()</span><br></pre></td></tr></table></figure>
<h3 id="3-可视化页面展示">3.可视化页面展示</h3>
<p><img src="/2023/03/14/Pytorch%E6%A1%86%E6%9E%B6%E6%B7%B1%E5%BA%A6%E5%AD%A6%E4%B9%A0%E5%85%A5%E9%97%A8%E5%9F%BA%E7%A1%80/test2.png" alt="test2"></p>
<h2 id="十五、nn-optim（优化器）">十五、nn_optim（优化器）</h2>
<figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br><span class="line">32</span><br><span class="line">33</span><br><span class="line">34</span><br><span class="line">35</span><br><span class="line">36</span><br><span class="line">37</span><br><span class="line">38</span><br><span class="line">39</span><br><span class="line">40</span><br><span class="line">41</span><br><span class="line">42</span><br><span class="line">43</span><br><span class="line">44</span><br><span class="line">45</span><br><span class="line">46</span><br><span class="line">47</span><br><span class="line">48</span><br><span class="line">49</span><br><span class="line">50</span><br><span class="line">51</span><br><span class="line">52</span><br><span class="line">53</span><br><span class="line">54</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">import</span> torch</span><br><span class="line"><span class="keyword">import</span> torchvision</span><br><span class="line"><span 
class="keyword">from</span> torch <span class="keyword">import</span> nn</span><br><span class="line"><span class="keyword">from</span> torch.nn <span class="keyword">import</span> Conv2d, MaxPool2d, Flatten, Linear, Sequential</span><br><span class="line"><span class="keyword">from</span> torch.utils.data <span class="keyword">import</span> DataLoader</span><br><span class="line"><span class="keyword">from</span> torch.utils.tensorboard <span class="keyword">import</span> SummaryWriter</span><br><span class="line"></span><br><span class="line"></span><br><span class="line"><span class="comment"># 导入数据集</span></span><br><span class="line">dataset = torchvision.datasets.CIFAR10(<span class="string">&quot;../data&quot;</span>, train=<span class="literal">False</span>,</span><br><span class="line">                                       transform=torchvision.transforms.ToTensor(), download=<span class="literal">True</span>)</span><br><span class="line"></span><br><span class="line">dataloader = DataLoader(dataset, batch_size=<span class="number">64</span>)</span><br><span class="line"></span><br><span class="line"></span><br><span class="line"><span class="comment"># 神经网络模板</span></span><br><span class="line"><span class="keyword">class</span> <span class="title class_">Tudui</span>(nn.Module):</span><br><span class="line"></span><br><span class="line">    <span class="keyword">def</span> <span class="title function_">__init__</span>(<span class="params">self</span>):</span><br><span class="line">        <span class="built_in">super</span>(Tudui, self).__init__()</span><br><span class="line">        self.model1 = Sequential(</span><br><span class="line">            Conv2d(<span class="number">3</span>, <span class="number">32</span>, <span class="number">5</span>, padding=<span class="number">2</span>),</span><br><span class="line">            MaxPool2d(<span class="number">2</span>),</span><br><span class="line">            Conv2d(<span class="number">32</span>, <span 
class="number">32</span>, <span class="number">5</span>, padding=<span class="number">2</span>),</span><br><span class="line">            MaxPool2d(<span class="number">2</span>),</span><br><span class="line">            Conv2d(<span class="number">32</span>, <span class="number">64</span>, <span class="number">5</span>, padding=<span class="number">2</span>),</span><br><span class="line">            MaxPool2d(<span class="number">2</span>),</span><br><span class="line">            Flatten(),</span><br><span class="line">            Linear(<span class="number">1024</span>, <span class="number">64</span>),</span><br><span class="line">            Linear(<span class="number">64</span>, <span class="number">10</span>)</span><br><span class="line">        )</span><br><span class="line"></span><br><span class="line">    <span class="keyword">def</span> <span class="title function_">forward</span>(<span class="params">self, x</span>):</span><br><span class="line">        x = self.model1(x)</span><br><span class="line">        <span class="keyword">return</span> x</span><br><span class="line"></span><br><span class="line"></span><br><span class="line"><span class="comment"># 计算损失、梯度下降</span></span><br><span class="line"></span><br><span class="line">loss = nn.CrossEntropyLoss()</span><br><span class="line">tudui = Tudui()</span><br><span class="line">optim = torch.optim.SGD(tudui.parameters(), lr=<span class="number">0.01</span>)</span><br><span class="line"></span><br><span class="line"><span class="keyword">for</span> epoch <span class="keyword">in</span> <span class="built_in">range</span>(<span class="number">20</span>):</span><br><span class="line">    running_loss = <span class="number">0.0</span></span><br><span class="line">    <span class="keyword">for</span> data <span class="keyword">in</span> dataloader:</span><br><span class="line">        imgs, targets = data</span><br><span class="line">        outputs = tudui(imgs)</span><br><span class="line">        
result_loss = loss(outputs, targets)</span><br><span class="line">        optim.zero_grad()</span><br><span class="line">        result_loss.backward()</span><br><span class="line">        optim.step()</span><br><span class="line">        running_loss += result_loss.item()</span><br><span class="line">    <span class="built_in">print</span>(running_loss)</span><br></pre></td></tr></table></figure>
<h2 id="十六、model-pretrained（对现有网络模型进行修改-vgg16）">十六、model_pretrained（对现有网络模型进行修改.vgg16）</h2>
<figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">import</span> torchvision</span><br><span class="line"><span class="keyword">from</span> torch <span class="keyword">import</span> nn</span><br><span class="line"></span><br><span class="line"></span><br><span class="line"><span class="comment"># train_data = torchvision.datasets.ImageNet(&quot;../data_image_net&quot;, split=&#x27;train&#x27;,</span></span><br><span class="line"><span class="comment">#                                            download=True, transform=torchvision.transforms.ToTensor())</span></span><br><span class="line"></span><br><span class="line"></span><br><span class="line">vgg16_false = torchvision.models.vgg16(pretrained=<span class="literal">False</span>)</span><br><span class="line">vgg16_true = torchvision.models.vgg16(pretrained=<span class="literal">True</span>)</span><br><span class="line"></span><br><span class="line"></span><br><span class="line"></span><br><span class="line">train_data = torchvision.datasets.CIFAR10(<span class="string">&quot;../data&quot;</span>, train=<span class="literal">False</span>,</span><br><span class="line">                                       
transform=torchvision.transforms.ToTensor(), download=<span class="literal">True</span>)</span><br><span class="line"></span><br><span class="line"></span><br><span class="line">vgg16_true.classifier.add_module(<span class="string">&#x27;add_linear&#x27;</span>, nn.Linear(<span class="number">1000</span>, <span class="number">10</span>))</span><br><span class="line"><span class="built_in">print</span>(vgg16_true)</span><br><span class="line"></span><br><span class="line">vgg16_false.classifier[<span class="number">6</span>] = nn.Linear(<span class="number">4096</span>, <span class="number">10</span>)</span><br><span class="line"><span class="built_in">print</span>(vgg16_false)</span><br></pre></td></tr></table></figure>
<h2 id="十七、模型的保存与读取">十七、模型的保存与读取</h2>
<h3 id="1-model-save">1.model_save</h3>
<figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">import</span> torch</span><br><span class="line"><span class="keyword">import</span> torchvision</span><br><span class="line"></span><br><span class="line"></span><br><span class="line">vgg16 = torchvision.models.vgg16(pretrained=<span class="literal">False</span>)</span><br><span class="line"></span><br><span class="line"><span class="comment"># 保存方式1:模型结构+模型参数</span></span><br><span class="line">torch.save(vgg16, <span class="string">&quot;vgg16_method1.pth&quot;</span>)</span><br><span class="line"></span><br><span class="line"><span class="comment"># 保存方式2:模型参数（官方推荐）</span></span><br><span class="line">torch.save(vgg16.state_dict(), <span class="string">&quot;vgg16_method2.pth&quot;</span>)</span><br></pre></td></tr></table></figure>
<h3 id="2-model-load">2.model_load</h3>
<figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">import</span> torch</span><br><span class="line"><span class="keyword">import</span> torchvision</span><br><span class="line"></span><br><span class="line"></span><br><span class="line"><span class="comment"># 方式1：加载模型</span></span><br><span class="line"><span class="comment"># model = torch.load(&quot;vgg16_method1.pth&quot;)</span></span><br><span class="line"><span class="comment"># print(model)</span></span><br><span class="line"></span><br><span class="line"><span class="comment"># 方式2：加载模型</span></span><br><span class="line">vgg16 = torchvision.models.vgg16(pretrained=<span class="literal">False</span>)</span><br><span class="line">vgg16.load_state_dict(torch.load(<span class="string">&quot;vgg16_method2.pth&quot;</span>))</span><br><span class="line"><span class="comment"># model = torch.load(&quot;vgg16_method2.pth&quot;)</span></span><br><span class="line"><span class="built_in">print</span>(vgg16)</span><br></pre></td></tr></table></figure>
<h2 id="十八、train（完整模型训练套路）">十八、train（完整模型训练套路）</h2>
<figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br><span class="line">32</span><br><span class="line">33</span><br><span class="line">34</span><br><span class="line">35</span><br><span class="line">36</span><br><span class="line">37</span><br><span class="line">38</span><br><span class="line">39</span><br><span class="line">40</span><br><span class="line">41</span><br><span class="line">42</span><br><span class="line">43</span><br><span class="line">44</span><br><span class="line">45</span><br><span class="line">46</span><br><span class="line">47</span><br><span class="line">48</span><br><span class="line">49</span><br><span class="line">50</span><br><span class="line">51</span><br><span class="line">52</span><br><span class="line">53</span><br><span class="line">54</span><br><span class="line">55</span><br><span class="line">56</span><br><span class="line">57</span><br><span class="line">58</span><br><span class="line">59</span><br><span class="line">60</span><br><span 
class="line">61</span><br><span class="line">62</span><br><span class="line">63</span><br><span class="line">64</span><br><span class="line">65</span><br><span class="line">66</span><br><span class="line">67</span><br><span class="line">68</span><br><span class="line">69</span><br><span class="line">70</span><br><span class="line">71</span><br><span class="line">72</span><br><span class="line">73</span><br><span class="line">74</span><br><span class="line">75</span><br><span class="line">76</span><br><span class="line">77</span><br><span class="line">78</span><br><span class="line">79</span><br><span class="line">80</span><br><span class="line">81</span><br><span class="line">82</span><br><span class="line">83</span><br><span class="line">84</span><br><span class="line">85</span><br><span class="line">86</span><br><span class="line">87</span><br><span class="line">88</span><br><span class="line">89</span><br><span class="line">90</span><br><span class="line">91</span><br><span class="line">92</span><br><span class="line">93</span><br><span class="line">94</span><br><span class="line">95</span><br><span class="line">96</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">import</span> torch.optim</span><br><span class="line"><span class="keyword">import</span> torchvision</span><br><span class="line"><span class="keyword">from</span> torch <span class="keyword">import</span> nn</span><br><span class="line"><span class="keyword">from</span> torch.nn <span class="keyword">import</span> Sequential, Conv2d, MaxPool2d, Flatten, Linear</span><br><span class="line"><span class="keyword">from</span> torch.utils.data <span class="keyword">import</span> DataLoader</span><br><span class="line"><span class="keyword">from</span> torch.utils.tensorboard <span class="keyword">import</span> SummaryWriter</span><br><span class="line"></span><br><span class="line"><span class="keyword">from</span> model <span class="keyword">import</span> 
*</span><br><span class="line"></span><br><span class="line"></span><br><span class="line"><span class="comment"># 准备数据集</span></span><br><span class="line">train_data = torchvision.datasets.CIFAR10(root=<span class="string">&quot;../data&quot;</span>, train=<span class="literal">True</span>,</span><br><span class="line">                                           download=<span class="literal">True</span>, transform=torchvision.transforms.ToTensor())</span><br><span class="line"></span><br><span class="line">test_data = torchvision.datasets.CIFAR10(root=<span class="string">&quot;../data&quot;</span>, train=<span class="literal">False</span>,</span><br><span class="line">                                           download=<span class="literal">True</span>, transform=torchvision.transforms.ToTensor())</span><br><span class="line"></span><br><span class="line"></span><br><span class="line"><span class="comment"># 数据集长度</span></span><br><span class="line">train_data_size = <span class="built_in">len</span>(train_data)</span><br><span class="line">test_data_size = <span class="built_in">len</span>(test_data)</span><br><span class="line"><span class="built_in">print</span>(<span class="string">&quot;训练数据集的长度为：&#123;&#125;&quot;</span>.<span class="built_in">format</span>(train_data_size))</span><br><span class="line"><span class="built_in">print</span>(<span class="string">&quot;测试数据集的长度为：&#123;&#125;&quot;</span>.<span class="built_in">format</span>(test_data_size))</span><br><span class="line"></span><br><span class="line"><span class="comment"># 利用 DataLoader 来加载数据集</span></span><br><span class="line">train_dataloader = DataLoader(train_data, batch_size=<span class="number">64</span>)</span><br><span class="line">test_dataloader = DataLoader(test_data, batch_size=<span class="number">64</span>)</span><br><span class="line"></span><br><span class="line"><span class="comment"># 创建网络模型</span></span><br><span class="line">tudui = Tudui()</span><br><span 
class="line"></span><br><span class="line"><span class="comment"># 损失函数</span></span><br><span class="line">loss_fn = nn.CrossEntropyLoss()</span><br><span class="line"></span><br><span class="line"><span class="comment"># 优化器</span></span><br><span class="line">learning_rate = <span class="number">0.01</span></span><br><span class="line">optimizer = torch.optim.SGD(tudui.parameters(), lr=learning_rate)</span><br><span class="line"></span><br><span class="line"><span class="comment"># 设置训练网络的参数</span></span><br><span class="line"><span class="comment"># 记录训练的次数</span></span><br><span class="line">total_train_step = <span class="number">0</span></span><br><span class="line"><span class="comment"># 记录测试的次数</span></span><br><span class="line">total_test_step = <span class="number">0</span></span><br><span class="line"><span class="comment"># 训练的轮数</span></span><br><span class="line">epoch = <span class="number">10</span></span><br><span class="line"></span><br><span class="line"></span><br><span class="line"><span class="comment"># 添加tensorboard</span></span><br><span class="line">writer = SummaryWriter(<span class="string">&quot;../logs_train&quot;</span>)</span><br><span class="line"></span><br><span class="line"></span><br><span class="line"><span class="comment"># 开始训练</span></span><br><span class="line"><span class="keyword">for</span> i <span class="keyword">in</span> <span class="built_in">range</span>(epoch):</span><br><span class="line"></span><br><span class="line">    <span class="built_in">print</span>(<span class="string">&quot;-----第&#123;&#125;轮训练开始-----&quot;</span>.<span class="built_in">format</span>(i+<span class="number">1</span>))</span><br><span class="line"></span><br><span class="line">    <span class="comment"># 训练步骤开始</span></span><br><span class="line">    tudui.train()</span><br><span class="line">    <span class="keyword">for</span> data <span class="keyword">in</span> train_dataloader:</span><br><span class="line">        imgs, targets = 
data</span><br><span class="line">        outputs = tudui(imgs)</span><br><span class="line">        loss = loss_fn(outputs,targets)</span><br><span class="line">        <span class="comment"># 优化器优化模型</span></span><br><span class="line">        optimizer.zero_grad()</span><br><span class="line">        loss.backward()</span><br><span class="line">        optimizer.step()</span><br><span class="line">        <span class="comment"># 显示次数和损失loss</span></span><br><span class="line">        total_train_step += <span class="number">1</span></span><br><span class="line">        <span class="keyword">if</span> total_train_step % <span class="number">100</span> == <span class="number">0</span>:</span><br><span class="line">            <span class="built_in">print</span>(<span class="string">&quot;训练次数:&#123;&#125;, Loss:&#123;&#125;&quot;</span>.<span class="built_in">format</span>(total_train_step, loss.item()))</span><br><span class="line">            writer.add_scalar(<span class="string">&quot;train_loss&quot;</span>, loss.item(), total_train_step)</span><br><span class="line"></span><br><span class="line">    <span class="comment"># 测试步骤开始</span></span><br><span class="line">    tudui.<span class="built_in">eval</span>()</span><br><span class="line">    total_test_loss = <span class="number">0</span></span><br><span class="line">    total_accuracy = <span class="number">0</span></span><br><span class="line">    <span class="keyword">with</span> torch.no_grad():</span><br><span class="line">        <span class="keyword">for</span> data <span class="keyword">in</span> test_dataloader:</span><br><span class="line">            imgs, targets = data</span><br><span class="line">            outputs = tudui(imgs)</span><br><span class="line">            loss = loss_fn(outputs, targets)</span><br><span class="line">            total_test_loss += loss.item()</span><br><span class="line">            accuracy = (outputs.argmax(<span class="number">1</span>) == targets).<span 
class="built_in">sum</span>()</span><br><span class="line">            total_accuracy += accuracy</span><br><span class="line"></span><br><span class="line">    <span class="built_in">print</span>(<span class="string">&quot;整体测试集上的Loss:&#123;&#125;&quot;</span>.<span class="built_in">format</span>(total_test_loss))</span><br><span class="line">    <span class="built_in">print</span>(<span class="string">&quot;整体测试集上的正确率:&#123;&#125;&quot;</span>.<span class="built_in">format</span>(total_accuracy/test_data_size))</span><br><span class="line">    writer.add_scalar(<span class="string">&quot;test_loss&quot;</span>, total_test_loss, total_test_step)</span><br><span class="line">    writer.add_scalar(<span class="string">&quot;test_accuracy&quot;</span>, total_accuracy/test_data_size, total_test_step)</span><br><span class="line">    total_test_step += <span class="number">1</span></span><br><span class="line"></span><br><span class="line">    <span class="comment"># torch.save(tudui, &quot;model_classify_&#123;&#125;.pth&quot;.format(i))</span></span><br><span class="line">    <span class="comment"># print(&quot;模型已保存！&quot;)</span></span><br><span class="line"></span><br><span class="line"></span><br><span class="line">writer.close()</span><br></pre></td></tr></table></figure>
<h3 id="可视化页面展示：-2">可视化页面展示：</h3>
<p><img src="/2023/03/14/Pytorch%E6%A1%86%E6%9E%B6%E6%B7%B1%E5%BA%A6%E5%AD%A6%E4%B9%A0%E5%85%A5%E9%97%A8%E5%9F%BA%E7%A1%80/test3.png" alt="TensorBoard 可视化页面截图"></p>
<h2 id="十九、利用GPU训练">十九、利用GPU训练</h2>
<h3 id="1-train-gpu-1">1.train_gpu_1</h3>
<figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br><span class="line">32</span><br><span class="line">33</span><br><span class="line">34</span><br><span class="line">35</span><br><span class="line">36</span><br><span class="line">37</span><br><span class="line">38</span><br><span class="line">39</span><br><span class="line">40</span><br><span class="line">41</span><br><span class="line">42</span><br><span class="line">43</span><br><span class="line">44</span><br><span class="line">45</span><br><span class="line">46</span><br><span class="line">47</span><br><span class="line">48</span><br><span class="line">49</span><br><span class="line">50</span><br><span class="line">51</span><br><span class="line">52</span><br><span class="line">53</span><br><span class="line">54</span><br><span class="line">55</span><br><span class="line">56</span><br><span class="line">57</span><br><span class="line">58</span><br><span class="line">59</span><br><span class="line">60</span><br><span 
class="line">61</span><br><span class="line">62</span><br><span class="line">63</span><br><span class="line">64</span><br><span class="line">65</span><br><span class="line">66</span><br><span class="line">67</span><br><span class="line">68</span><br><span class="line">69</span><br><span class="line">70</span><br><span class="line">71</span><br><span class="line">72</span><br><span class="line">73</span><br><span class="line">74</span><br><span class="line">75</span><br><span class="line">76</span><br><span class="line">77</span><br><span class="line">78</span><br><span class="line">79</span><br><span class="line">80</span><br><span class="line">81</span><br><span class="line">82</span><br><span class="line">83</span><br><span class="line">84</span><br><span class="line">85</span><br><span class="line">86</span><br><span class="line">87</span><br><span class="line">88</span><br><span class="line">89</span><br><span class="line">90</span><br><span class="line">91</span><br><span class="line">92</span><br><span class="line">93</span><br><span class="line">94</span><br><span class="line">95</span><br><span class="line">96</span><br><span class="line">97</span><br><span class="line">98</span><br><span class="line">99</span><br><span class="line">100</span><br><span class="line">101</span><br><span class="line">102</span><br><span class="line">103</span><br><span class="line">104</span><br><span class="line">105</span><br><span class="line">106</span><br><span class="line">107</span><br><span class="line">108</span><br><span class="line">109</span><br><span class="line">110</span><br><span class="line">111</span><br><span class="line">112</span><br><span class="line">113</span><br><span class="line">114</span><br><span class="line">115</span><br><span class="line">116</span><br><span class="line">117</span><br><span class="line">118</span><br><span class="line">119</span><br><span class="line">120</span><br><span class="line">121</span><br></pre></td><td 
class="code"><pre><span class="line"><span class="keyword">import</span> time</span><br><span class="line"><span class="keyword">from</span> model <span class="keyword">import</span> *</span><br><span class="line"><span class="keyword">import</span> torch.optim</span><br><span class="line"><span class="keyword">import</span> torchvision</span><br><span class="line"><span class="keyword">from</span> torch <span class="keyword">import</span> nn</span><br><span class="line"><span class="keyword">from</span> torch.nn <span class="keyword">import</span> Sequential, Conv2d, MaxPool2d, Flatten, Linear</span><br><span class="line"><span class="keyword">from</span> torch.utils.data <span class="keyword">import</span> DataLoader</span><br><span class="line"><span class="keyword">from</span> torch.utils.tensorboard <span class="keyword">import</span> SummaryWriter</span><br><span class="line"></span><br><span class="line"></span><br><span class="line"></span><br><span class="line"><span class="comment"># 准备数据集</span></span><br><span class="line">train_data = torchvision.datasets.CIFAR10(root=<span class="string">&quot;../data&quot;</span>, train=<span class="literal">True</span>,</span><br><span class="line">                                           download=<span class="literal">True</span>, transform=torchvision.transforms.ToTensor())</span><br><span class="line"></span><br><span class="line">test_data = torchvision.datasets.CIFAR10(root=<span class="string">&quot;../data&quot;</span>, train=<span class="literal">False</span>,</span><br><span class="line">                                           download=<span class="literal">True</span>, transform=torchvision.transforms.ToTensor())</span><br><span class="line"></span><br><span class="line"></span><br><span class="line"><span class="comment"># 数据集长度</span></span><br><span class="line">train_data_size = <span class="built_in">len</span>(train_data)</span><br><span class="line">test_data_size = <span 
class="built_in">len</span>(test_data)</span><br><span class="line"><span class="built_in">print</span>(<span class="string">&quot;训练数据集的长度为：&#123;&#125;&quot;</span>.<span class="built_in">format</span>(train_data_size))</span><br><span class="line"><span class="built_in">print</span>(<span class="string">&quot;测试数据集的长度为：&#123;&#125;&quot;</span>.<span class="built_in">format</span>(test_data_size))</span><br><span class="line"></span><br><span class="line"><span class="comment"># 利用 DataLoader 来加载数据集</span></span><br><span class="line">train_dataloader = DataLoader(train_data, batch_size=<span class="number">64</span>)</span><br><span class="line">test_dataloader = DataLoader(test_data, batch_size=<span class="number">64</span>)</span><br><span class="line"></span><br><span class="line"><span class="comment"># 创建网络模型</span></span><br><span class="line">tudui = Tudui()</span><br><span class="line"></span><br><span class="line"><span class="keyword">if</span> torch.cuda.is_available():</span><br><span class="line">    tudui = tudui.cuda()</span><br><span class="line"></span><br><span class="line"></span><br><span class="line"><span class="comment"># 损失函数</span></span><br><span class="line">loss_fn = nn.CrossEntropyLoss()</span><br><span class="line"></span><br><span class="line"><span class="keyword">if</span> torch.cuda.is_available():</span><br><span class="line">    loss_fn = loss_fn.cuda()</span><br><span class="line"></span><br><span class="line"><span class="comment"># 优化器</span></span><br><span class="line">learning_rate = <span class="number">0.01</span></span><br><span class="line">optimizer = torch.optim.SGD(tudui.parameters(), lr=learning_rate)</span><br><span class="line"></span><br><span class="line"></span><br><span class="line"><span class="comment"># 设置训练网络的参数</span></span><br><span class="line"><span class="comment"># 记录训练的次数</span></span><br><span class="line">total_train_step = <span class="number">0</span></span><br><span class="line"><span 
class="comment"># 记录测试的次数</span></span><br><span class="line">total_test_step = <span class="number">0</span></span><br><span class="line"><span class="comment"># 训练的轮数</span></span><br><span class="line">epoch = <span class="number">10</span></span><br><span class="line"></span><br><span class="line"></span><br><span class="line"><span class="comment"># 添加tensorboard</span></span><br><span class="line">writer = SummaryWriter(<span class="string">&quot;../logs_train&quot;</span>)</span><br><span class="line">start_time = time.time()</span><br><span class="line"></span><br><span class="line"></span><br><span class="line"><span class="comment"># 开始训练</span></span><br><span class="line"><span class="keyword">for</span> i <span class="keyword">in</span> <span class="built_in">range</span>(epoch):</span><br><span class="line"></span><br><span class="line">    <span class="built_in">print</span>(<span class="string">&quot;-----第&#123;&#125;轮训练开始-----&quot;</span>.<span class="built_in">format</span>(i+<span class="number">1</span>))</span><br><span class="line"></span><br><span class="line">    <span class="comment"># 训练步骤开始</span></span><br><span class="line">    tudui.train()</span><br><span class="line"></span><br><span class="line">    <span class="keyword">for</span> data <span class="keyword">in</span> train_dataloader:</span><br><span class="line">        imgs, targets = data</span><br><span class="line"></span><br><span class="line">        <span class="keyword">if</span> torch.cuda.is_available():</span><br><span class="line">            imgs = imgs.cuda()</span><br><span class="line">            targets = targets.cuda()</span><br><span class="line"></span><br><span class="line">        outputs = tudui(imgs)</span><br><span class="line">        loss = loss_fn(outputs,targets)</span><br><span class="line">        <span class="comment"># 优化器优化模型</span></span><br><span class="line">        optimizer.zero_grad()</span><br><span class="line">        
loss.backward()</span><br><span class="line">        optimizer.step()</span><br><span class="line">        <span class="comment"># 显示次数和损失loss</span></span><br><span class="line">        total_train_step += <span class="number">1</span></span><br><span class="line">        <span class="keyword">if</span> total_train_step % <span class="number">1000</span> == <span class="number">0</span>:</span><br><span class="line">            end_time = time.time()</span><br><span class="line">            <span class="built_in">print</span>(end_time - start_time)</span><br><span class="line">            <span class="built_in">print</span>(<span class="string">&quot;训练次数:&#123;&#125;, Loss:&#123;&#125;&quot;</span>.<span class="built_in">format</span>(total_train_step, loss.item()))</span><br><span class="line">            writer.add_scalar(<span class="string">&quot;train_loss&quot;</span>, loss.item(), total_train_step)</span><br><span class="line"></span><br><span class="line">    <span class="comment"># 测试步骤开始</span></span><br><span class="line">    tudui.<span class="built_in">eval</span>()</span><br><span class="line">    total_test_loss = <span class="number">0</span></span><br><span class="line">    total_accuracy = <span class="number">0</span></span><br><span class="line"></span><br><span class="line">    <span class="keyword">with</span> torch.no_grad():</span><br><span class="line">        <span class="keyword">for</span> data <span class="keyword">in</span> test_dataloader:</span><br><span class="line">            imgs, targets = data</span><br><span class="line"></span><br><span class="line">            <span class="keyword">if</span> torch.cuda.is_available():</span><br><span class="line">                imgs = imgs.cuda()</span><br><span class="line">                targets = targets.cuda()</span><br><span class="line"></span><br><span class="line">            outputs = tudui(imgs)</span><br><span class="line">            loss = loss_fn(outputs, 
targets)</span><br><span class="line">            total_test_loss += loss.item()</span><br><span class="line">            accuracy = (outputs.argmax(<span class="number">1</span>) == targets).<span class="built_in">sum</span>()</span><br><span class="line">            total_accuracy += accuracy</span><br><span class="line"></span><br><span class="line">    <span class="built_in">print</span>(<span class="string">&quot;整体测试集上的Loss:&#123;&#125;&quot;</span>.<span class="built_in">format</span>(total_test_loss))</span><br><span class="line">    <span class="built_in">print</span>(<span class="string">&quot;整体测试集上的正确率:&#123;&#125;&quot;</span>.<span class="built_in">format</span>(total_accuracy/test_data_size))</span><br><span class="line">    writer.add_scalar(<span class="string">&quot;test_loss&quot;</span>, total_test_loss, total_test_step)</span><br><span class="line">    writer.add_scalar(<span class="string">&quot;test_accuracy&quot;</span>, total_accuracy/test_data_size, total_test_step)</span><br><span class="line">    total_test_step += <span class="number">1</span></span><br><span class="line"></span><br><span class="line">    <span class="keyword">if</span> i == epoch - <span class="number">1</span>:</span><br><span class="line">        torch.save(tudui, <span class="string">&quot;model_&#123;&#125;.pth&quot;</span>.<span class="built_in">format</span>(i+<span class="number">1</span>))</span><br><span class="line">        <span class="built_in">print</span>(<span class="string">&quot;模型已保存！&quot;</span>)</span><br><span class="line"></span><br><span class="line"></span><br><span class="line">writer.close()</span><br></pre></td></tr></table></figure>
<h3 id="2-train-gpu-2">2.train_gpu_2</h3>
<figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br><span class="line">32</span><br><span class="line">33</span><br><span class="line">34</span><br><span class="line">35</span><br><span class="line">36</span><br><span class="line">37</span><br><span class="line">38</span><br><span class="line">39</span><br><span class="line">40</span><br><span class="line">41</span><br><span class="line">42</span><br><span class="line">43</span><br><span class="line">44</span><br><span class="line">45</span><br><span class="line">46</span><br><span class="line">47</span><br><span class="line">48</span><br><span class="line">49</span><br><span class="line">50</span><br><span class="line">51</span><br><span class="line">52</span><br><span class="line">53</span><br><span class="line">54</span><br><span class="line">55</span><br><span class="line">56</span><br><span class="line">57</span><br><span class="line">58</span><br><span class="line">59</span><br><span class="line">60</span><br><span 
class="line">61</span><br><span class="line">62</span><br><span class="line">63</span><br><span class="line">64</span><br><span class="line">65</span><br><span class="line">66</span><br><span class="line">67</span><br><span class="line">68</span><br><span class="line">69</span><br><span class="line">70</span><br><span class="line">71</span><br><span class="line">72</span><br><span class="line">73</span><br><span class="line">74</span><br><span class="line">75</span><br><span class="line">76</span><br><span class="line">77</span><br><span class="line">78</span><br><span class="line">79</span><br><span class="line">80</span><br><span class="line">81</span><br><span class="line">82</span><br><span class="line">83</span><br><span class="line">84</span><br><span class="line">85</span><br><span class="line">86</span><br><span class="line">87</span><br><span class="line">88</span><br><span class="line">89</span><br><span class="line">90</span><br><span class="line">91</span><br><span class="line">92</span><br><span class="line">93</span><br><span class="line">94</span><br><span class="line">95</span><br><span class="line">96</span><br><span class="line">97</span><br><span class="line">98</span><br><span class="line">99</span><br><span class="line">100</span><br><span class="line">101</span><br><span class="line">102</span><br><span class="line">103</span><br><span class="line">104</span><br><span class="line">105</span><br><span class="line">106</span><br><span class="line">107</span><br><span class="line">108</span><br><span class="line">109</span><br><span class="line">110</span><br><span class="line">111</span><br><span class="line">112</span><br><span class="line">113</span><br><span class="line">114</span><br><span class="line">115</span><br><span class="line">116</span><br><span class="line">117</span><br><span class="line">118</span><br><span class="line">119</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">import</span> 
time</span><br><span class="line"><span class="keyword">from</span> model <span class="keyword">import</span> *</span><br><span class="line"><span class="keyword">import</span> torch.optim</span><br><span class="line"><span class="keyword">import</span> torchvision</span><br><span class="line"><span class="keyword">from</span> torch <span class="keyword">import</span> nn</span><br><span class="line"><span class="keyword">from</span> torch.nn <span class="keyword">import</span> Sequential, Conv2d, MaxPool2d, Flatten, Linear</span><br><span class="line"><span class="keyword">from</span> torch.utils.data <span class="keyword">import</span> DataLoader</span><br><span class="line"><span class="keyword">from</span> torch.utils.tensorboard <span class="keyword">import</span> SummaryWriter</span><br><span class="line"></span><br><span class="line"></span><br><span class="line"><span class="comment"># 定义训练设备</span></span><br><span class="line">device = torch.device(<span class="string">&quot;cuda&quot;</span> <span class="keyword">if</span> torch.cuda.is_available() <span class="keyword">else</span> <span class="string">&quot;cpu&quot;</span>)</span><br><span class="line"></span><br><span class="line"></span><br><span class="line"><span class="comment"># 准备数据集</span></span><br><span class="line">train_data = torchvision.datasets.CIFAR10(root=<span class="string">&quot;../data&quot;</span>, train=<span class="literal">True</span>,</span><br><span class="line">                download=<span class="literal">True</span>, transform=torchvision.transforms.ToTensor())</span><br><span class="line"></span><br><span class="line">test_data = torchvision.datasets.CIFAR10(root=<span class="string">&quot;../data&quot;</span>, train=<span class="literal">False</span>,</span><br><span class="line">                download=<span class="literal">True</span>, transform=torchvision.transforms.ToTensor())</span><br><span class="line"></span><br><span class="line"></span><br><span 
class="line"><span class="comment"># 数据集长度</span></span><br><span class="line">train_data_size = <span class="built_in">len</span>(train_data)</span><br><span class="line">test_data_size = <span class="built_in">len</span>(test_data)</span><br><span class="line"><span class="built_in">print</span>(<span class="string">&quot;训练数据集的长度为：&#123;&#125;&quot;</span>.<span class="built_in">format</span>(train_data_size))</span><br><span class="line"><span class="built_in">print</span>(<span class="string">&quot;测试数据集的长度为：&#123;&#125;&quot;</span>.<span class="built_in">format</span>(test_data_size))</span><br><span class="line"></span><br><span class="line"><span class="comment"># 利用 DataLoader 来加载数据集</span></span><br><span class="line">train_dataloader = DataLoader(train_data, batch_size=<span class="number">64</span>)</span><br><span class="line">test_dataloader = DataLoader(test_data, batch_size=<span class="number">64</span>)</span><br><span class="line"></span><br><span class="line"><span class="comment"># 创建网络模型</span></span><br><span class="line">tudui = Tudui()</span><br><span class="line">tudui.to(device)</span><br><span class="line"></span><br><span class="line"></span><br><span class="line"><span class="comment"># 损失函数</span></span><br><span class="line">loss_fn = nn.CrossEntropyLoss()</span><br><span class="line">loss_fn.to(device)</span><br><span class="line"></span><br><span class="line"><span class="comment"># 优化器</span></span><br><span class="line">learning_rate = <span class="number">0.01</span></span><br><span class="line">optimizer = torch.optim.SGD(tudui.parameters(), lr=learning_rate)</span><br><span class="line"></span><br><span class="line"></span><br><span class="line"><span class="comment"># 设置训练网络的参数</span></span><br><span class="line"><span class="comment"># 记录训练的次数</span></span><br><span class="line">total_train_step = <span class="number">0</span></span><br><span class="line"><span class="comment"># 记录测试的次数</span></span><br><span 
class="line">total_test_step = <span class="number">0</span></span><br><span class="line"><span class="comment"># 训练的轮数</span></span><br><span class="line">epoch = <span class="number">30</span></span><br><span class="line"></span><br><span class="line"></span><br><span class="line"><span class="comment"># 添加tensorboard</span></span><br><span class="line">writer = SummaryWriter(<span class="string">&quot;../logs_train&quot;</span>)</span><br><span class="line">start_time = time.time()</span><br><span class="line"></span><br><span class="line"></span><br><span class="line"><span class="comment"># 开始训练</span></span><br><span class="line"><span class="keyword">for</span> i <span class="keyword">in</span> <span class="built_in">range</span>(epoch):</span><br><span class="line"></span><br><span class="line">    <span class="built_in">print</span>(<span class="string">&quot;-----第&#123;&#125;轮训练开始-----&quot;</span>.<span class="built_in">format</span>(i+<span class="number">1</span>))</span><br><span class="line"></span><br><span class="line">    <span class="comment"># 训练步骤开始</span></span><br><span class="line">    tudui.train()</span><br><span class="line"></span><br><span class="line">    <span class="keyword">for</span> data <span class="keyword">in</span> train_dataloader:</span><br><span class="line">        imgs, targets = data</span><br><span class="line"></span><br><span class="line">        imgs = imgs.to(device)</span><br><span class="line">        targets = targets.to(device)</span><br><span class="line"></span><br><span class="line">        outputs = tudui(imgs)</span><br><span class="line">        loss = loss_fn(outputs, targets)</span><br><span class="line">        <span class="comment"># 优化器优化模型</span></span><br><span class="line">        optimizer.zero_grad()</span><br><span class="line">        loss.backward()</span><br><span class="line">        optimizer.step()</span><br><span class="line">        <span class="comment"># 
显示次数和损失loss</span></span><br><span class="line">        total_train_step += <span class="number">1</span></span><br><span class="line">        <span class="keyword">if</span> total_train_step % <span class="number">100</span> == <span class="number">0</span>:</span><br><span class="line">            writer.add_scalar(<span class="string">&quot;train_loss&quot;</span>, loss.item(), total_train_step)</span><br><span class="line"></span><br><span class="line">    <span class="comment"># 测试步骤开始</span></span><br><span class="line">    tudui.<span class="built_in">eval</span>()</span><br><span class="line">    total_test_loss = <span class="number">0</span></span><br><span class="line">    total_accuracy = <span class="number">0</span></span><br><span class="line"></span><br><span class="line">    <span class="keyword">with</span> torch.no_grad():</span><br><span class="line">        <span class="keyword">for</span> data <span class="keyword">in</span> test_dataloader:</span><br><span class="line">            imgs, targets = data</span><br><span class="line"></span><br><span class="line">            imgs = imgs.to(device)</span><br><span class="line">            targets = targets.to(device)</span><br><span class="line"></span><br><span class="line">            outputs = tudui(imgs)</span><br><span class="line">            loss = loss_fn(outputs, targets)</span><br><span class="line">            total_test_loss += loss.item()</span><br><span class="line">            accuracy = (outputs.argmax(<span class="number">1</span>) == targets).<span class="built_in">sum</span>()</span><br><span class="line">            total_accuracy += accuracy</span><br><span class="line"></span><br><span class="line">    end_time = time.time()</span><br><span class="line">    <span class="built_in">print</span>(<span class="string">&quot;time: &#123;&#125;s&quot;</span>.<span class="built_in">format</span>(end_time - start_time))</span><br><span class="line">    <span 
class="built_in">print</span>(<span class="string">&quot;训练次数:&#123;&#125;, Loss:&#123;&#125;&quot;</span>.<span class="built_in">format</span>(total_train_step, loss.item()))</span><br><span class="line"></span><br><span class="line">    <span class="built_in">print</span>(<span class="string">&quot;整体测试集上的Loss:&#123;&#125;&quot;</span>.<span class="built_in">format</span>(total_test_loss))</span><br><span class="line">    <span class="built_in">print</span>(<span class="string">&quot;整体测试集上的正确率:&#123;&#125;&quot;</span>.<span class="built_in">format</span>(total_accuracy/test_data_size))</span><br><span class="line">    writer.add_scalar(<span class="string">&quot;test_loss&quot;</span>, total_test_loss, total_test_step)</span><br><span class="line">    writer.add_scalar(<span class="string">&quot;test_accuracy&quot;</span>, total_accuracy/test_data_size, total_test_step)</span><br><span class="line">    total_test_step += <span class="number">1</span></span><br><span class="line"></span><br><span class="line">    <span class="keyword">if</span> i == <span class="number">29</span>:</span><br><span class="line">        torch.save(tudui, <span class="string">&quot;model_&#123;&#125;.pth&quot;</span>.<span class="built_in">format</span>(i + <span class="number">1</span>))</span><br><span class="line">        <span class="built_in">print</span>(<span class="string">&quot;模型已保存！&quot;</span>)</span><br><span class="line"></span><br><span class="line"></span><br><span class="line">writer.close()</span><br></pre></td></tr></table></figure>
<h2 id="二十、test（模型验证）">二十、test（模型验证）</h2>
<figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br><span class="line">32</span><br><span class="line">33</span><br><span class="line">34</span><br><span class="line">35</span><br><span class="line">36</span><br><span class="line">37</span><br><span class="line">38</span><br><span class="line">39</span><br><span class="line">40</span><br><span class="line">41</span><br><span class="line">42</span><br><span class="line">43</span><br><span class="line">44</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">import</span> torch</span><br><span class="line"><span class="keyword">import</span> torchvision</span><br><span class="line"><span class="keyword">from</span> torch <span class="keyword">import</span> nn</span><br><span class="line"><span class="keyword">from</span> PIL <span class="keyword">import</span> Image</span><br><span class="line"><span class="keyword">from</span> torch.nn <span class="keyword">import</span> Sequential, Conv2d, MaxPool2d, 
Flatten, Linear</span><br><span class="line"></span><br><span class="line"></span><br><span class="line">image_path = <span class="string">&quot;../images/dog.png&quot;</span></span><br><span class="line">image = Image.<span class="built_in">open</span>(image_path)</span><br><span class="line"></span><br><span class="line">transform = torchvision.transforms.Compose([torchvision.transforms.Resize((<span class="number">32</span>, <span class="number">32</span>)),</span><br><span class="line">                                            torchvision.transforms.ToTensor()])</span><br><span class="line"></span><br><span class="line">image = transform(image)</span><br><span class="line"></span><br><span class="line"></span><br><span class="line"><span class="keyword">class</span> <span class="title class_">Tudui</span>(nn.Module):</span><br><span class="line"></span><br><span class="line">    <span class="keyword">def</span> <span class="title function_">__init__</span>(<span class="params">self</span>):</span><br><span class="line">        <span class="built_in">super</span>(Tudui, self).__init__()</span><br><span class="line">        self.model1 = Sequential(</span><br><span class="line">            Conv2d(<span class="number">3</span>, <span class="number">32</span>, <span class="number">5</span>, padding=<span class="number">2</span>),</span><br><span class="line">            MaxPool2d(<span class="number">2</span>),</span><br><span class="line">            Conv2d(<span class="number">32</span>, <span class="number">32</span>, <span class="number">5</span>, padding=<span class="number">2</span>),</span><br><span class="line">            MaxPool2d(<span class="number">2</span>),</span><br><span class="line">            Conv2d(<span class="number">32</span>, <span class="number">64</span>, <span class="number">5</span>, padding=<span class="number">2</span>),</span><br><span class="line">            MaxPool2d(<span class="number">2</span>),</span><br><span class="line">  
          Flatten(),</span><br><span class="line">            Linear(<span class="number">1024</span>, <span class="number">64</span>),</span><br><span class="line">            Linear(<span class="number">64</span>, <span class="number">10</span>)</span><br><span class="line">        )</span><br><span class="line"></span><br><span class="line">    <span class="keyword">def</span> <span class="title function_">forward</span>(<span class="params">self, x</span>):</span><br><span class="line">        x = self.model1(x)</span><br><span class="line">        <span class="keyword">return</span> x</span><br><span class="line"></span><br><span class="line">model = torch.load(<span class="string">&quot;model_30.pth&quot;</span>, map_location=torch.device(<span class="string">&#x27;cpu&#x27;</span>))</span><br><span class="line"></span><br><span class="line">image = torch.reshape(image, (<span class="number">1</span>, <span class="number">3</span>, <span class="number">32</span>, <span class="number">32</span>))</span><br><span class="line">model.<span class="built_in">eval</span>()</span><br><span class="line"><span class="keyword">with</span> torch.no_grad():</span><br><span class="line">    output = model(image)</span><br><span class="line"></span><br><span class="line"><span class="built_in">print</span>(output.argmax(<span class="number">1</span>))</span><br></pre></td></tr></table></figure>
<hr>
</article><div class="post-copyright"><div class="post-copyright__author"><span class="post-copyright-meta">Author: </span><span class="post-copyright-info"><a href="https://lwkblog.gitee.io">1025</a></span></div><div class="post-copyright__type"><span class="post-copyright-meta">Link: </span><span class="post-copyright-info"><a href="https://lwkblog.gitee.io/2023/03/14/Pytorch%E6%A1%86%E6%9E%B6%E6%B7%B1%E5%BA%A6%E5%AD%A6%E4%B9%A0%E5%85%A5%E9%97%A8%E5%9F%BA%E7%A1%80/">https://lwkblog.gitee.io/2023/03/14/Pytorch%E6%A1%86%E6%9E%B6%E6%B7%B1%E5%BA%A6%E5%AD%A6%E4%B9%A0%E5%85%A5%E9%97%A8%E5%9F%BA%E7%A1%80/</a></span></div><div class="post-copyright__notice"><span class="post-copyright-meta">Copyright Notice: </span><span class="post-copyright-info">All articles in this blog are licensed under <a target="_blank" rel="noopener" href="https://creativecommons.org/licenses/by-nc-sa/4.0/">CC BY-NC-SA 4.0</a> unless stating additionally.</span></div></div><div class="tag_share"><div class="post-meta__tag-list"></div><div class="post_share"><div class="social-share" data-image="https://pic1.zhimg.com/v2-5ea151d43c0ebb13546985d225ac256a_1200x500.jpg" data-sites="facebook,twitter,wechat,weibo,qq"></div><link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/butterfly-extsrc/sharejs/dist/css/share.min.css" media="print" onload="this.media='all'"><script src="https://cdn.jsdelivr.net/npm/butterfly-extsrc/sharejs/dist/js/social-share.min.js" defer></script></div></div><div class="post-reward"><div class="reward-button"><i class="fas fa-qrcode"></i> Donate</div><div class="reward-main"><ul class="reward-all"><li class="reward-item"><a href="/img/wechat.jpg" target="_blank"><img class="post-qr-code-img" src="/img/wechat.jpg" alt="wechat"/></a><div class="post-qr-code-desc">wechat</div></li></ul></div></div><hr/><div id="post-comment"><div class="comment-head"><div class="comment-headline"><i class="fas fa-comments fa-fw"></i><span> Comment</span></div><div id="comment-switch"><span 
class="first-comment">Valine</span><span class="switch-btn"></span><span class="second-comment">Disqus</span></div></div><div class="comment-wrap"><div><div class="vcomment" id="vcomment"></div></div><div><div id="disqus_thread"></div></div></div></div></div><div class="aside-content" id="aside-content"><div class="card-widget card-info"><div class="is-center"><div class="avatar-img"><img src="https://pic1.zhimg.com/50/v2-530bac941983a6aae866ab34b6290be2_720w.jpg?source=1940ef5c" onerror="this.onerror=null;this.src='/img/friend_404.gif'" alt="avatar"/></div><div class="author-info__name">1025</div><div class="author-info__description"></div></div><div class="card-info-data site-data is-center"><a href="/archives/"><div class="headline">Articles</div><div class="length-num">12</div></a><a href="/tags/"><div class="headline">Tags</div><div class="length-num">0</div></a><a href="/categories/"><div class="headline">Categories</div><div class="length-num">5</div></a></div><a id="card-info-btn" target="_blank" rel="noopener" href="https://codeforces.com/profile/2607563994"><i class="fa-solid fa-code-compare"></i><span>MyCodeforces</span></a><div class="card-info-social-icons is-center"><a class="social-icon" href="https://gitee.com/qq2607563994" target="_blank" title="MyGit"><i class="fa-brands fa-git-alt"></i></a><a class="social-icon" href="https://www.acwing.com/user/myspace/index/226518/" target="_blank" title="MyAcwing"><i class="fa-solid fa-font"></i></a><a class="social-icon" href="https://www.nowcoder.com/" target="_blank" title="poo"><i class="fa-solid fa-poo"></i></a></div></div><div class="card-widget card-announcement"><div class="item-headline"><i class="fas fa-bullhorn fa-shake"></i><span>Announcement</span></div><div class="announcement_content">Welcome to my blog, here is the world of learning~</div></div><div class="sticky_layout"><div class="card-widget" id="card-toc"><div class="item-headline"><i class="fas fa-stream"></i><span>Catalog</span><span 
class="toc-percentage"></span></div><div class="toc-content"><ol class="toc"><li class="toc-item toc-level-2"><a class="toc-link" href="#%E9%A6%96%E5%85%88%E5%A3%B0%E6%98%8E%EF%BC%81%EF%BC%81%EF%BC%81"><span class="toc-number">1.</span> <span class="toc-text">首先声明！！！</span></a></li><li class="toc-item toc-level-2"><a class="toc-link" href="#%E4%B8%80%E3%80%81main-demo"><span class="toc-number">2.</span> <span class="toc-text">一、main_demo</span></a></li><li class="toc-item toc-level-2"><a class="toc-link" href="#%E4%BA%8C%E3%80%81read-data%EF%BC%88%E8%AF%BB%E5%8F%96%E6%95%B0%E6%8D%AE%EF%BC%89"><span class="toc-number">3.</span> <span class="toc-text">二、read_data（读取数据）</span></a></li><li class="toc-item toc-level-2"><a class="toc-link" href="#%E4%B8%89%E3%80%81test-tb%EF%BC%88%E7%AE%80%E5%8D%95%E6%B5%8B%E8%AF%95%EF%BC%89"><span class="toc-number">4.</span> <span class="toc-text">三、test_tb（简单测试）</span></a><ol class="toc-child"><li class="toc-item toc-level-3"><a class="toc-link" href="#%E5%8F%AF%E8%A7%86%E5%8C%96%E9%A1%B5%E9%9D%A2%E5%B1%95%E7%A4%BA%EF%BC%9A"><span class="toc-number">4.1.</span> <span class="toc-text">可视化页面展示：</span></a></li></ol></li><li class="toc-item toc-level-2"><a class="toc-link" href="#%E5%9B%9B%E3%80%81Transforms%EF%BC%88%E6%95%B0%E6%8D%AE%E8%BD%AC%E6%8D%A2%EF%BC%89"><span class="toc-number">5.</span> <span class="toc-text">四、Transforms（数据转换）</span></a></li><li class="toc-item toc-level-2"><a class="toc-link" href="#%E4%BA%94%E3%80%81UseTransforms%EF%BC%88%E4%BD%BF%E7%94%A8%E6%95%B0%E6%8D%AE%E8%BD%AC%E6%8D%A2%EF%BC%89"><span class="toc-number">6.</span> <span class="toc-text">五、UseTransforms（使用数据转换）</span></a></li><li class="toc-item toc-level-2"><a class="toc-link" href="#%E5%85%AD%E3%80%81dataset-transforms%EF%BC%88%E6%95%B0%E6%8D%AE%E9%9B%86%E4%BD%BF%E7%94%A8%EF%BC%89"><span class="toc-number">7.</span> <span class="toc-text">六、dataset_transforms（数据集使用）</span></a></li><li class="toc-item toc-level-2"><a class="toc-link" 
href="#%E4%B8%83%E3%80%81read-data%EF%BC%88%E8%AF%BB%E5%8F%96%E6%95%B0%E6%8D%AE%EF%BC%89"><span class="toc-number">8.</span> <span class="toc-text">七、read_data（读取数据）</span></a></li><li class="toc-item toc-level-2"><a class="toc-link" href="#%E5%85%AB%E3%80%81nn-module%EF%BC%88%E6%A8%A1%E5%9E%8B%E5%9F%BA%E7%A1%80%EF%BC%89"><span class="toc-number">9.</span> <span class="toc-text">八、nn_module（模型基础）</span></a></li><li class="toc-item toc-level-2"><a class="toc-link" href="#%E4%B9%9D%E3%80%81nn-conv2d%EF%BC%88%E5%8D%B7%E7%A7%AF%E5%B1%82%EF%BC%89"><span class="toc-number">10.</span> <span class="toc-text">九、nn_conv2d（卷积层）</span></a></li><li class="toc-item toc-level-2"><a class="toc-link" href="#%E5%8D%81%E3%80%81nn-maxpool%EF%BC%88%E6%9C%80%E5%A4%A7%E6%B1%A0%E5%8C%96%EF%BC%89"><span class="toc-number">11.</span> <span class="toc-text">十、nn_maxpool（最大池化）</span></a></li><li class="toc-item toc-level-2"><a class="toc-link" href="#%E5%8D%81%E4%B8%80%E3%80%81nn-relu%EF%BC%88%E9%9D%9E%E7%BA%BF%E6%80%A7%E6%BF%80%E6%B4%BB%EF%BC%89"><span class="toc-number">12.</span> <span class="toc-text">十一、nn_relu（非线性激活）</span></a></li><li class="toc-item toc-level-2"><a class="toc-link" href="#%E5%8D%81%E4%BA%8C%E3%80%81nn-linear%EF%BC%88%E7%BA%BF%E6%80%A7%E5%B1%82%EF%BC%89"><span class="toc-number">13.</span> <span class="toc-text">十二、nn_linear（线性层）</span></a></li><li class="toc-item toc-level-2"><a class="toc-link" href="#%E5%8D%81%E4%B8%89%E3%80%81nn-seq%EF%BC%88%E6%90%AD%E5%BB%BA%E5%B0%8F%E5%AE%9E%E6%88%98%EF%BC%89"><span class="toc-number">14.</span> <span class="toc-text">十三、nn_seq（搭建小实战）</span></a></li><li class="toc-item toc-level-2"><a class="toc-link" href="#%E5%8D%81%E5%9B%9B%E3%80%81%E6%8D%9F%E5%A4%B1%E5%87%BD%E6%95%B0%E4%B8%8E%E5%8F%8D%E5%90%91%E4%BC%A0%E6%92%AD"><span class="toc-number">15.</span> <span class="toc-text">十四、损失函数与反向传播</span></a><ol class="toc-child"><li class="toc-item toc-level-3"><a class="toc-link" href="#1-nn-loss"><span class="toc-number">15.1.</span> 
<span class="toc-text">1.nn_loss</span></a></li><li class="toc-item toc-level-3"><a class="toc-link" href="#2-nn-loss-network"><span class="toc-number">15.2.</span> <span class="toc-text">2.nn_loss_network</span></a></li><li class="toc-item toc-level-3"><a class="toc-link" href="#3-%E5%8F%AF%E8%A7%86%E5%8C%96%E9%A1%B5%E9%9D%A2%E5%B1%95%E7%A4%BA"><span class="toc-number">15.3.</span> <span class="toc-text">3.可视化页面展示</span></a></li></ol></li><li class="toc-item toc-level-2"><a class="toc-link" href="#%E5%8D%81%E4%BA%94%E3%80%81nn-optim%EF%BC%88%E4%BC%98%E5%8C%96%E5%99%A8%EF%BC%89"><span class="toc-number">16.</span> <span class="toc-text">十五、nn_optim（优化器）</span></a></li><li class="toc-item toc-level-2"><a class="toc-link" href="#%E5%8D%81%E5%85%AD%E3%80%81model-pretrained%EF%BC%88%E5%AF%B9%E7%8E%B0%E6%9C%89%E7%BD%91%E7%BB%9C%E6%A8%A1%E5%9E%8B%E8%BF%9B%E8%A1%8C%E4%BF%AE%E6%94%B9-vgg16%EF%BC%89"><span class="toc-number">17.</span> <span class="toc-text">十六、model_pretrained（对现有网络模型进行修改.vgg16）</span></a></li><li class="toc-item toc-level-2"><a class="toc-link" href="#%E5%8D%81%E4%B8%83%E3%80%81%E6%A8%A1%E5%9E%8B%E7%9A%84%E4%BF%9D%E5%AD%98%E4%B8%8E%E8%AF%BB%E5%8F%96"><span class="toc-number">18.</span> <span class="toc-text">十七、模型的保存与读取</span></a><ol class="toc-child"><li class="toc-item toc-level-3"><a class="toc-link" href="#1-model-save"><span class="toc-number">18.1.</span> <span class="toc-text">1.model_save</span></a></li><li class="toc-item toc-level-3"><a class="toc-link" href="#2-model-load"><span class="toc-number">18.2.</span> <span class="toc-text">2.model_load</span></a></li></ol></li><li class="toc-item toc-level-2"><a class="toc-link" href="#%E5%8D%81%E5%85%AB%E3%80%81train%EF%BC%88%E5%AE%8C%E6%95%B4%E6%A8%A1%E5%9E%8B%E8%AE%AD%E7%BB%83%E5%A5%97%E8%B7%AF%EF%BC%89"><span class="toc-number">19.</span> <span class="toc-text">十八、train（完整模型训练套路）</span></a><ol class="toc-child"><li class="toc-item toc-level-3"><a class="toc-link" 
href="#%E5%8F%AF%E8%A7%86%E5%8C%96%E9%A1%B5%E9%9D%A2%E5%B1%95%E7%A4%BA%EF%BC%9A-2"><span class="toc-number">19.1.</span> <span class="toc-text">可视化页面展示：</span></a></li></ol></li><li class="toc-item toc-level-2"><a class="toc-link" href="#%E5%8D%81%E4%B9%9D%E3%80%81%E5%88%A9%E7%94%A8GPU%E8%AE%AD%E7%BB%83"><span class="toc-number">20.</span> <span class="toc-text">十九、利用GPU训练</span></a><ol class="toc-child"><li class="toc-item toc-level-3"><a class="toc-link" href="#1-train-gpu-1"><span class="toc-number">20.1.</span> <span class="toc-text">1.train_gpu_1</span></a></li><li class="toc-item toc-level-3"><a class="toc-link" href="#2-train-gpu-2"><span class="toc-number">20.2.</span> <span class="toc-text">2.train_gpu_2</span></a></li></ol></li><li class="toc-item toc-level-2"><a class="toc-link" href="#%E4%BA%8C%E5%8D%81%E3%80%81test%EF%BC%88%E6%A8%A1%E5%9E%8B%E9%AA%8C%E8%AF%81%EF%BC%89"><span class="toc-number">21.</span> <span class="toc-text">二十、test（模型验证）</span></a></li></ol></div></div><div class="card-widget card-recent-post"><div class="item-headline"><i class="fas fa-history"></i><span>Recent Post</span></div><div class="aside-list"><div class="aside-list-item"><a class="thumbnail" href="/2024/03/07/Codeforces%201900-2200/" title="Codeforces 1900-2200"><img src="https://th.bing.com/th/id/R.8764d162d1e9be6b2cf6d348e2da99f0?rik=TLZjyw4Nspa%2b0w&amp;riu=http%3a%2f%2fpic.616pic.com%2fys_bnew_img%2f00%2f62%2f36%2fr9dfcyoyjR.jpg&amp;ehk=SdJuWe8fxNWlX58TeKImDSWatngZpwGh6ann2DJ%2fXN0%3d&amp;risl=&amp;pid=ImgRaw&amp;r=0" onerror="this.onerror=null;this.src='/img/404.jpg'" alt="Codeforces 1900-2200"/></a><div class="content"><a class="title" href="/2024/03/07/Codeforces%201900-2200/" title="Codeforces 1900-2200">Codeforces 1900-2200</a><time datetime="2024-03-07T12:28:21.000Z" title="Created 2024-03-07 20:28:21">2024-03-07</time></div></div><div class="aside-list-item"><a class="thumbnail" 
href="/2023/12/22/%E8%AE%B0%E5%BD%95%E7%AC%AC%E4%B8%80%E6%AC%A1%E6%89%93icpc%E5%8C%BA%E5%9F%9F%E8%B5%9B%E7%9A%84%E6%84%9F%E5%8F%97/" title="记录第一次打icpc区域赛的感受"><img src="https://tse1-mm.cn.bing.net/th/id/OIP-C.3HHSbVgysU3bljjMphclBwHaE8?rs=1&amp;pid=ImgDetMain" onerror="this.onerror=null;this.src='/img/404.jpg'" alt="记录第一次打icpc区域赛的感受"/></a><div class="content"><a class="title" href="/2023/12/22/%E8%AE%B0%E5%BD%95%E7%AC%AC%E4%B8%80%E6%AC%A1%E6%89%93icpc%E5%8C%BA%E5%9F%9F%E8%B5%9B%E7%9A%84%E6%84%9F%E5%8F%97/" title="记录第一次打icpc区域赛的感受">记录第一次打icpc区域赛的感受</a><time datetime="2023-12-22T00:06:41.000Z" title="Created 2023-12-22 08:06:41">2023-12-22</time></div></div><div class="aside-list-item"><a class="thumbnail" href="/2023/12/19/2023%E8%AE%A1%E7%AE%97%E6%9C%BA%E5%8D%8F%E4%BC%9A%E7%BC%96%E7%A8%8B%E5%A4%A7%E8%B5%9B%E9%A2%98%E8%A7%A3%E4%BB%A3%E7%A0%81/" title="HLC_2023年计算机协会编程大赛题解代码"><img src="https://th.bing.com/th/id/R.8764d162d1e9be6b2cf6d348e2da99f0?rik=TLZjyw4Nspa%2b0w&amp;riu=http%3a%2f%2fpic.616pic.com%2fys_bnew_img%2f00%2f62%2f36%2fr9dfcyoyjR.jpg&amp;ehk=SdJuWe8fxNWlX58TeKImDSWatngZpwGh6ann2DJ%2fXN0%3d&amp;risl=&amp;pid=ImgRaw&amp;r=0" onerror="this.onerror=null;this.src='/img/404.jpg'" alt="HLC_2023年计算机协会编程大赛题解代码"/></a><div class="content"><a class="title" href="/2023/12/19/2023%E8%AE%A1%E7%AE%97%E6%9C%BA%E5%8D%8F%E4%BC%9A%E7%BC%96%E7%A8%8B%E5%A4%A7%E8%B5%9B%E9%A2%98%E8%A7%A3%E4%BB%A3%E7%A0%81/" title="HLC_2023年计算机协会编程大赛题解代码">HLC_2023年计算机协会编程大赛题解代码</a><time datetime="2023-12-19T00:06:41.000Z" title="Created 2023-12-19 08:06:41">2023-12-19</time></div></div><div class="aside-list-item"><a class="thumbnail" href="/2023/06/03/Codeforces%201500-1800/" title="Codeforces 1500-1800"><img src="https://th.bing.com/th/id/R.8764d162d1e9be6b2cf6d348e2da99f0?rik=TLZjyw4Nspa%2b0w&amp;riu=http%3a%2f%2fpic.616pic.com%2fys_bnew_img%2f00%2f62%2f36%2fr9dfcyoyjR.jpg&amp;ehk=SdJuWe8fxNWlX58TeKImDSWatngZpwGh6ann2DJ%2fXN0%3d&amp;risl=&amp;pid=ImgRaw&amp;r=0" 
onerror="this.onerror=null;this.src='/img/404.jpg'" alt="Codeforces 1500-1800"/></a><div class="content"><a class="title" href="/2023/06/03/Codeforces%201500-1800/" title="Codeforces 1500-1800">Codeforces 1500-1800</a><time datetime="2023-06-03T12:28:21.000Z" title="Created 2023-06-03 20:28:21">2023-06-03</time></div></div><div class="aside-list-item"><a class="thumbnail" href="/2023/06/02/Codeforces%201100-1400/" title="Codeforces 1100-1400"><img src="https://th.bing.com/th/id/R.8764d162d1e9be6b2cf6d348e2da99f0?rik=TLZjyw4Nspa%2b0w&amp;riu=http%3a%2f%2fpic.616pic.com%2fys_bnew_img%2f00%2f62%2f36%2fr9dfcyoyjR.jpg&amp;ehk=SdJuWe8fxNWlX58TeKImDSWatngZpwGh6ann2DJ%2fXN0%3d&amp;risl=&amp;pid=ImgRaw&amp;r=0" onerror="this.onerror=null;this.src='/img/404.jpg'" alt="Codeforces 1100-1400"/></a><div class="content"><a class="title" href="/2023/06/02/Codeforces%201100-1400/" title="Codeforces 1100-1400">Codeforces 1100-1400</a><time datetime="2023-06-02T15:08:21.000Z" title="Created 2023-06-02 23:08:21">2023-06-02</time></div></div></div></div></div></div></main><footer id="footer" style="background-image: url('https://pic1.zhimg.com/v2-5ea151d43c0ebb13546985d225ac256a_1200x500.jpg')"><div id="footer-wrap"><div class="copyright">&copy;2023 - 2024 By 1025</div><div class="framework-info"><span>Framework </span><a target="_blank" rel="noopener" href="https://hexo.io">Hexo</a><span class="footer-separator">|</span><span>Theme </span><a target="_blank" rel="noopener" href="https://github.com/jerryc127/hexo-theme-butterfly">Butterfly</a></div><div class="footer_custom_text">Hi, welcome to my <a href="https://lwkblog.gitee.io/">blog</a>!</div></div></footer></div><div id="rightside"><div id="rightside-config-hide"><button id="readmode" type="button" title="Read Mode"><i class="fas fa-book-open"></i></button><button id="darkmode" type="button" title="Switch Between Light And Dark Mode"><i class="fas fa-adjust"></i></button><button id="hide-aside-btn" type="button" title="Toggle 
between single-column and double-column"><i class="fas fa-arrows-alt-h"></i></button></div><div id="rightside-config-show"><button id="rightside_config" type="button" title="Setting"><i class="fas fa-cog fa-spin"></i></button><button class="close" id="mobile-toc-button" type="button" title="Table Of Contents"><i class="fas fa-list-ul"></i></button><button id="chat_btn" type="button" title="Chat"><i class="fas fa-sms"></i></button><a id="to_comment" href="#post-comment" title="Scroll To Comments"><i class="fas fa-comments"></i></a><button id="go-up" type="button" title="Back To Top"><span class="scroll-percent"></span><i class="fas fa-arrow-up"></i></button></div></div><div id="local-search"><div class="search-dialog"><nav class="search-nav"><span class="search-dialog-title">Search</span><span id="loading-status"></span><button class="search-close-button"><i class="fas fa-times"></i></button></nav><div class="is-center" id="loading-database"><i class="fas fa-spinner fa-pulse"></i><span>  Loading the Database</span></div><div class="search-wrap"><div id="local-search-input"><div class="local-search-box"><input class="local-search-box--input" placeholder="Search for Posts" type="text"/></div></div><hr/><div id="local-search-results"></div></div></div><div id="search-mask"></div></div><div><script src="/js/utils.js"></script><script src="/js/main.js"></script><script src="https://cdn.jsdelivr.net/npm/@fancyapps/ui/dist/fancybox.umd.min.js"></script><script src="/js/search/local-search.js"></script><div class="js-pjax"><script>if (!window.MathJax) {
  // MathJax configuration object; it is assigned before the loader script
  // further down is injected into <head>.
  window.MathJax = {
    tex: {
      // Inline math may be delimited by $...$ or by \(...\).
      inlineMath: [ ['$','$'], ["\\(","\\)"]],
      tags: 'ams'
    },
    chtml: {
      scale: 1.1
    },
    options: {
      // Custom render actions. NOTE(review): the array layout looks like
      // [priority, per-document function, per-item function, ...] as described
      // in the MathJax renderActions documentation — confirm against the docs.
      renderActions: {
        // Registers a MathItem for every <script type="math/tex..."> element and
        // swaps the script node for an empty text node that marks both the start
        // and the end position of the math.
        findScript: [10, doc => {
          for (const node of document.querySelectorAll('script[type^="math/tex"]')) {
            // "; mode=display" inside the type attribute selects display math.
            const display = !!node.type.match(/; *mode=display/)
            const math = new doc.options.MathItem(node.textContent, doc.inputJax[0], display)
            const text = document.createTextNode('')
            node.parentNode.replaceChild(text, node)
            math.start = {node: text, delim: '', n: 0}
            math.end = {node: text, delim: '', n: 0}
            doc.math.push(math)
          }
        }, ''],
        // Passes every <mjx-container> to btf.wrap with class "mathjax-overflow":
        // a div when the container has a `display` attribute, a span otherwise.
        // btf.wrap is defined outside this script — presumably it wraps the node
        // in a new element; verify in /js/utils.js.
        insertScript: [200, () => {
          document.querySelectorAll('mjx-container').forEach(node => {
            if (node.hasAttribute('display')) {
              btf.wrap(node, 'div', { class: 'mathjax-overflow' })
            } else {
              btf.wrap(node, 'span', { class: 'mathjax-overflow' })
            }
          });
        }, '', false]
      }
    }
  }
  
  // Load MathJax itself asynchronously from the CDN.
  const script = document.createElement('script')
  script.src = 'https://cdn.jsdelivr.net/npm/mathjax/es5/tex-mml-chtml.min.js'
  script.id = 'MathJax-script'
  script.async = true
  document.head.appendChild(script)
} else {
  // window.MathJax already exists (presumably because this inline script was
  // re-run after a Pjax navigation — it lives inside the .js-pjax container):
  // reset the document and TeX state, then typeset the page again.
  MathJax.startup.document.state(0)
  MathJax.texReset()
  MathJax.typeset()
}</script><script>(() => {
  // Mermaid support: every `.mermaid-wrap` inside the article keeps the diagram
  // source in its first child element; the rendered SVG is inserted right after
  // that element.
  const diagramWraps = document.querySelectorAll('#article-container .mermaid-wrap')
  if (diagramWraps.length) {
    window.runMermaid = () => {
      // Flag read by the bootstrap below so the library is fetched only once.
      window.loadMermaid = true
      const isDark = document.documentElement.getAttribute('data-theme') === 'dark'
      const theme = isDark ? 'dark' : 'default'
      // Init directive prepended to each diagram so it follows the page theme.
      const themeDirective = '%%{init:{ \'theme\':\'' + theme + '\'}}%%\n'

      diagramWraps.forEach((wrap, position) => {
        const sourceNode = wrap.firstElementChild
        mermaid.mermaidAPI.render(
          'mermaid-' + position,
          themeDirective + sourceNode.textContent,
          (svgCode) => {
            sourceNode.insertAdjacentHTML('afterend', svgCode)
          }
        )
      })
    }

    // Render immediately when Mermaid was loaded before, otherwise fetch it first.
    const ensureMermaid = () => {
      if (window.loadMermaid) {
        runMermaid()
      } else {
        getScript('https://cdn.jsdelivr.net/npm/mermaid/dist/mermaid.min.js').then(runMermaid)
      }
    }

    // With window.pjax set, run right away; otherwise wait for DOMContentLoaded.
    if (window.pjax) {
      ensureMermaid()
    } else {
      document.addEventListener('DOMContentLoaded', ensureMermaid)
    }
  }
})()</script><script>function loadValine () {
  /**
   * Instantiate the Valine comment widget on `#vcomment`, keyed by the current
   * page path.
   */
  function initValine () {
    const valineOptions = {
      el: '#vcomment',
      appId: 'T4ZmL3yRi9k7xz6QEtS0H2jO-gzGzoHsz',
      appKey: 'dxkZy1hOCexMwxvUS4TBskNa',
      avatar: 'monsterid',
      serverURLs: '',
      emojiMaps: "",
      path: window.location.pathname,
      visitor: false
    }
    // Constructing the widget is the whole effect; the instance is not kept.
    new Valine(valineOptions)
  }

  if (typeof Valine === 'function') initValine() 
  else getScript('https://cdn.jsdelivr.net/npm/valine/dist/Valine.min.js').then(initValine)
}

if ('Valine' === 'Valine' || !true) {
  if (true) btf.loadComment(document.getElementById('vcomment'),loadValine)
  else setTimeout(loadValine, 0)
} else {
  function loadOtherComment () {
    loadValine()
  }
}</script><script>function loadDisqus () {
  // Page metadata for the Disqus thread (canonical URL, identifier, title).
  // NOTE(review): this is a function-local `var`, so within this script it only
  // reaches Disqus through DISQUS.reset() below — confirm whether embed.js is
  // also expected to find it as a global on first load.
  var disqus_config = function () {
    this.page.url = 'https://lwkblog.gitee.io/2023/03/14/Pytorch%E6%A1%86%E6%9E%B6%E6%B7%B1%E5%BA%A6%E5%AD%A6%E4%B9%A0%E5%85%A5%E9%97%A8%E5%9F%BA%E7%A1%80/'
    this.page.identifier = '/2023/03/14/Pytorch%E6%A1%86%E6%9E%B6%E6%B7%B1%E5%BA%A6%E5%AD%A6%E4%B9%A0%E5%85%A5%E9%97%A8%E5%9F%BA%E7%A1%80/'
    this.page.title = 'Pytorch框架深度学习入门基础'
  };

  // Stored on window so the darkmode click handler at the end of this function
  // can trigger it later.
  window.disqusReset = () => {
    DISQUS.reset({
      reload: true,
      config: disqus_config
    })
  }

  if (window.DISQUS) disqusReset()
  else {
    // Disqus is not loaded yet: inject its embed script.
    (function() { 
      var d = document, s = d.createElement('script');
      // NOTE(review): the host is "https://.disqus.com" — the subdomain where the
      // Disqus shortname would go is empty, so this URL cannot resolve. Confirm
      // the Disqus shortname in the site configuration.
      s.src = 'https://.disqus.com/embed.js';
      s.setAttribute('data-timestamp', +new Date());
      (d.head || d.body).appendChild(s);
    })();
  }

  // Reset the thread 200 ms after the dark-mode button is clicked.
  // NOTE(review): a new listener is added every time loadDisqus runs — verify
  // this function is only invoked once per page.
  document.getElementById('darkmode').addEventListener('click', () => {
    setTimeout(() => window.disqusReset(), 200)
  })
}

// As rendered, this condition is constant: 'Valine' === 'Disqus' is false and
// !true is false, so only the else branch ever runs on this page.
if ('Valine' === 'Disqus' || !true) {
  if (true) btf.loadComment(document.getElementById('disqus_thread'), loadDisqus)
  else loadDisqus()
} else {
  // Disqus is the second comment system on this page (Valine is the first), so
  // it is not loaded here; instead a loader for it is declared.
  // NOTE(review): presumably called by the comment-switch button handler defined
  // outside this script — verify. This is a function declaration inside a block
  // in a non-module script, so its visibility outside the block depends on
  // sloppy-mode hoisting rules.
  function loadOtherComment () {
    loadDisqus()
  }
}
</script></div><script src="https://cdn.jsdelivr.net/npm/blueimp-md5/js/md5.min.js"></script><script>window.addEventListener('load', () => {
  /**
   * Turn a comment's rendered HTML into a short plain-text preview.
   *
   * Images, links and code blocks collapse to the placeholders `[image]`,
   * `[link]` and `[code]`; every remaining tag is stripped and the result is
   * capped at 150 characters followed by `...`.
   *
   * @param {string} content - Comment body as HTML.
   * @returns {string} Preview text; `''` for an empty or non-string input.
   */
  const changeContent = (content) => {
    // A record without a usable body would throw on `.replace` and abort the
    // rendering of the whole list; treat it as an empty preview instead.
    if (typeof content !== 'string' || content === '') return ''

    content = content.replace(/<img.*?src="(.*?)"?[^\>]+>/ig, '[image]') // replace image link
    content = content.replace(/<a[^>]+?href=["']?([^"']+)["']?[^>]*>([^<]+)<\/a>/gi, '[link]') // replace url
    // `.` never matches a line break, so [\s\S] is used to cover code blocks that
    // span several lines; attributes such as class="language-js" are tolerated.
    content = content.replace(/<pre(?:\s[^>]*)?>\s*<code(?:\s[^>]*)?>[\s\S]*?<\/pre>/gi, '[code]') // replace code
    content = content.replace(/<[^>]+>/g, '') // remove html tag

    if (content.length > 150) {
      content = content.substring(0, 150) + '...'
    }
    return content
  }

  /**
   * Resolve the avatar URL for a commenter.
   *
   * @param {string} icon - Explicit avatar URL; returned unchanged when truthy.
   * @param {string} mail - Commenter e-mail, lower-cased and MD5-hashed for the
   *   Gravatar mirror when no explicit avatar exists.
   * @returns {string} Avatar URL.
   */
  const getIcon = (icon, mail) => {
    if (!icon) {
      // No explicit avatar: fall back to the Gravatar mirror with a monsterid default.
      const fallbackQuery = '?d=monsterid'
      const hashed = md5(mail.toLowerCase())
      return 'https://gravatar.loli.net/avatar/' + (hashed + fallbackQuery)
    }
    return icon
  }

  /**
   * Render the newest-comments list into `#card-newest-comments .aside-list`.
   *
   * Every field of a comment record originates from visitors, so each value is
   * HTML-escaped before it is interpolated into the markup that is assigned to
   * `innerHTML`. Without escaping, a quote or angle bracket in a nickname, URL
   * or comment body could break out of its attribute or element and inject
   * markup.
   *
   * @param {Array<{avatar: string, content: string, nick: string, url: string, date: string}>} array
   *   Comment records to display; an empty array renders `No Comment`.
   */
  const generateHtml = array => {
    // Escapes the characters that can open a tag or terminate an attribute value.
    const escapeMarkup = value => String(value)
      .replace(/</g, '&lt;')
      .replace(/>/g, '&gt;')
      .replace(/"/g, '&quot;')
      .replace(/'/g, '&#39;')
    // Full escaping for plain-text fields.
    const escapeText = value => escapeMarkup(String(value).replace(/&/g, '&amp;'))

    let result = ''

    if (array.length) {
      for (const item of array) {
        const url = escapeText(item.url)
        const avatar = escapeText(item.avatar)
        const nick = escapeText(item.nick)
        const date = escapeText(item.date)
        // `content` is tag-stripped HTML and may still hold character references
        // such as &lt;, so `&` is left alone to avoid double-escaping them.
        const content = escapeMarkup(item.content)

        result += '<div class=\'aside-list-item\'>'
        result += `<a href='${url}' class='thumbnail'><img src='${avatar}' alt='${nick}'></a>`

        result += `<div class='content'>
        <a class='comment' href='${url}' title='${content}'>${content}</a>
        <div class='name'><span>${nick} / </span><time datetime="${date}">${btf.diffDate(item.date, true)}</time></div>
        </div></div>`
      }
    } else {
      result += 'No Comment'
    }

    let $dom = document.querySelector('#card-newest-comments .aside-list')
    $dom.innerHTML= result
    window.lazyLoadInstance && window.lazyLoadInstance.update()
    window.pjax && window.pjax.refresh($dom)
  }

  /**
   * Fetch the six most recently created comments from the LeanCloud REST API,
   * cache the simplified records via `saveToLocal`, and render them. On any
   * failure an error message is written into the list container instead.
   */
  const getComment = () => {
    const apiHost = 'https://T4ZmL3yR.api.lncldglobal.com'

    const requestOptions = {
      "method": "GET",
      "headers": {
        "X-LC-Id": 'T4ZmL3yRi9k7xz6QEtS0H2jO-gzGzoHsz',
        "X-LC-Key": 'dxkZy1hOCexMwxvUS4TBskNa',
        "Content-Type": "application/json"
      }
    }

    // Reduce a raw comment record to the fields the sidebar list needs.
    const toListEntry = (record) => ({
      'avatar': getIcon(record.QQAvatar, record.mail),
      'content': changeContent(record.comment),
      'nick': record.nick,
      'url': record.url + '#' + record.objectId,
      'date': record.updatedAt
    })

    const showFailure = () => {
      const $dom = document.querySelector('#card-newest-comments .aside-list')
      $dom.innerHTML= "Unable to get the data, please make sure the settings are correct."
    }

    fetch(`${apiHost}/1.1/classes/Comment?limit=6&order=-createdAt`, requestOptions)
      .then(response => response.json())
      .then(data => {
        const entries = data.results.map(record => toListEntry(record))
        // Third argument is 10/(60*24), i.e. ten minutes expressed in days.
        saveToLocal.set('valine-newest-comments', JSON.stringify(entries), 10/(60*24))
        generateHtml(entries)
      })
      .catch(() => {
        showFailure()
      })
  }

  /**
   * Fill the newest-comments card: render the cached list when `saveToLocal`
   * has one, otherwise fetch the comments. Does nothing when the card is not on
   * the page.
   */
  const newestCommentInit = () => {
    if (!document.querySelector('#card-newest-comments .aside-list')) return

    const cached = saveToLocal.get('valine-newest-comments')
    if (!cached) {
      getComment()
      return
    }
    generateHtml(JSON.parse(cached))
  }

  newestCommentInit()
  document.addEventListener('pjax:complete', newestCommentInit)
})</script><script defer="defer" id="fluttering_ribbon" mobile="true" src="https://cdn.jsdelivr.net/npm/butterfly-extsrc/dist/canvas-fluttering-ribbon.min.js"></script><script src="https://cdn.jsdelivr.net/npm/butterfly-extsrc/dist/activate-power-mode.min.js"></script><script>POWERMODE.colorful = true;
POWERMODE.shake = false;
POWERMODE.mobile = true;
document.body.addEventListener('input', POWERMODE);
</script><script src="https://cdn.jsdelivr.net/npm/pjax/pjax.min.js"></script><script>let pjaxSelectors = ["head > title","#config-diff","#body-wrap","#rightside-config-hide","#rightside-config-show",".js-pjax"]

// Page-wide Pjax instance. Declared with a top-level `var`, so it is also
// reachable as window.pjax, which other inline scripts on this page test for.
var pjax = new Pjax({
  // Links handled by Pjax: every anchor except those that open in a new tab
  // (target="_blank") or whose href is exactly "/music/".
  elements: 'a:not([target="_blank"]):not([href="/music/"])',
  // Page regions listed in pjaxSelectors, declared just above this statement.
  selectors: pjaxSelectors,
  cacheBust: false,
  analytics: false,
  scrollRestoration: false
})

// Tear down per-page state when Pjax announces a navigation ('pjax:send').
document.addEventListener('pjax:send', () => {
  // Detach the scroll listeners registered for the page being left.
  if (window.tocScrollFn) window.removeEventListener('scroll', window.tocScrollFn)
  if (window.scrollCollect) window.removeEventListener('scroll', scrollCollect)

  document.getElementById('rightside').style.cssText = "opacity: ''; transform: ''"

  // Destroy every APlayer instance that is not marked as fixed.
  if (window.aplayers) {
    for (let idx = 0; idx < window.aplayers.length; idx++) {
      if (window.aplayers[idx].options.fixed) continue
      window.aplayers[idx].destroy()
    }
  }

  if (typeof typed === 'object') typed.destroy()

  // Leave read mode before the next page is shown.
  const bodyClasses = document.body.classList
  if (bodyClasses.contains('read-mode')) bodyClasses.remove('read-mode')

  if (typeof disqusjs === 'object') disqusjs.destroy()
})

// Re-initialise page features once Pjax reports the navigation as complete.
document.addEventListener('pjax:complete', () => {
  window.refreshFn()

  // Replace each <script data-pjax> with a freshly created copy carrying the
  // same attributes and inline code.
  for (const staleScript of document.querySelectorAll('script[data-pjax]')) {
    const freshScript = document.createElement('script')
    const code = staleScript.text || staleScript.textContent || staleScript.innerHTML || ""
    for (const attr of Array.from(staleScript.attributes)) {
      freshScript.setAttribute(attr.name, attr.value)
    }
    freshScript.appendChild(document.createTextNode(code))
    staleScript.parentNode.replaceChild(freshScript, staleScript)
  }

  if (GLOBAL_CONFIG.islazyload) window.lazyLoadInstance.update()

  // Optional integrations: each one runs only when its global is present.
  if (typeof chatBtnFn === 'function') chatBtnFn()
  if (typeof panguInit === 'function') panguInit()

  // google analytics
  if (typeof gtag === 'function') gtag('config', '', {'page_path': window.location.pathname})

  // baidu analytics
  if (typeof _hmt === 'object') _hmt.push(['_trackPageview',window.location.pathname])

  if (typeof loadMeting === 'function' && document.getElementsByClassName('aplayer').length) loadMeting()

  // prismjs
  if (typeof Prism === 'object') Prism.highlightAll()
})

document.addEventListener('pjax:error', (e) => {
  if (e.request.status === 404) {
    pjax.loadUrl('/404.html')
  }
})</script><script async data-pjax src="//busuanzi.ibruce.info/busuanzi/2.3/busuanzi.pure.mini.js"></script></div></body></html>